Skip to content

Commit 463f5d5

Browse files
committed
Ability to cancel a UDF on disconnect or timeout
1 parent 8819d3e commit 463f5d5

File tree

3 files changed

+143
-6
lines changed

3 files changed

+143
-6
lines changed

singlestoredb/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,11 @@
438438
environ=['SINGLESTOREDB_EXT_FUNC_PORT'],
439439
)
440440

441+
register_option(
442+
'external_function.timeout', 'int', check_int, 24*60*60,
443+
'Specifies the timeout in seconds for processing a batch of rows.',
444+
environ=['SINGLESTOREDB_EXT_FUNC_TIMEOUT'],
445+
)
441446

442447
#
443448
# Debugging options

singlestoredb/functions/decorator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def _func(
100100
name: Optional[str] = None,
101101
args: Optional[ParameterType] = None,
102102
returns: Optional[ReturnType] = None,
103+
timeout: Optional[int] = None,
103104
) -> Callable[..., Any]:
104105
"""Generic wrapper for UDF and TVF decorators."""
105106

@@ -108,6 +109,7 @@ def _func(
108109
name=name,
109110
args=expand_types(args),
110111
returns=expand_types(returns),
112+
timeout=timeout,
111113
).items() if v is not None
112114
}
113115

singlestoredb/functions/ext/asgi.py

Lines changed: 136 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
"""
2525
import argparse
2626
import asyncio
27+
import contextvars
2728
import dataclasses
29+
import functools
2830
import importlib.util
2931
import inspect
3032
import io
@@ -37,6 +39,7 @@
3739
import sys
3840
import tempfile
3941
import textwrap
42+
import threading
4043
import typing
4144
import urllib
4245
import zipfile
@@ -95,6 +98,15 @@
9598
func_map = itertools.starmap
9699

97100

101+
async def to_thread(
102+
func: Any, /, *args: Any, **kwargs: Dict[str, Any],
103+
) -> Any:
104+
loop = asyncio.get_running_loop()
105+
ctx = contextvars.copy_context()
106+
func_call = functools.partial(ctx.run, func, *args, **kwargs)
107+
return await loop.run_in_executor(None, func_call)
108+
109+
98110
# Use negative values to indicate unsigned ints / binary data / usec time precision
99111
rowdat_1_type_map = {
100112
'bool': ft.LONGLONG,
@@ -274,11 +286,19 @@ def build_udf_endpoint(
274286
if returns_data_format in ['scalar', 'list']:
275287

276288
async def do_func(
289+
cancel_event: threading.Event,
277290
row_ids: Sequence[int],
278291
rows: Sequence[Sequence[Any]],
279292
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
280293
'''Call function on given rows of data.'''
281-
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
294+
out = []
295+
for row in rows:
296+
if cancel_event.is_set():
297+
raise asyncio.CancelledError(
298+
'Function call was cancelled',
299+
)
300+
out.append(func(*row))
301+
return row_ids, list(zip(out))
282302

283303
return do_func
284304

@@ -309,6 +329,7 @@ def build_vector_udf_endpoint(
309329
array_cls = get_array_class(returns_data_format)
310330

311331
async def do_func(
332+
cancel_event: threading.Event,
312333
row_ids: Sequence[int],
313334
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
314335
) -> Tuple[
@@ -361,6 +382,7 @@ def build_tvf_endpoint(
361382
if returns_data_format in ['scalar', 'list']:
362383

363384
async def do_func(
385+
cancel_event: threading.Event,
364386
row_ids: Sequence[int],
365387
rows: Sequence[Sequence[Any]],
366388
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
@@ -369,6 +391,10 @@ async def do_func(
369391
out = []
370392
# Call function on each row of data
371393
for i, res in zip(row_ids, func_map(func, rows)):
394+
if cancel_event.is_set():
395+
raise asyncio.CancelledError(
396+
'Function call was cancelled',
397+
)
372398
out.extend(as_list_of_tuples(res))
373399
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
374400
return out_ids, out
@@ -402,6 +428,7 @@ def build_vector_tvf_endpoint(
402428
array_cls = get_array_class(returns_data_format)
403429

404430
async def do_func(
431+
cancel_event: threading.Event,
405432
row_ids: Sequence[int],
406433
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
407434
) -> Tuple[
@@ -458,6 +485,7 @@ def make_func(
458485
function_type = sig.get('function_type', 'udf')
459486
args_data_format = sig.get('args_data_format', 'scalar')
460487
returns_data_format = sig.get('returns_data_format', 'scalar')
488+
timeout = sig.get('timeout', get_option('external_function.timeout'))
461489

462490
if function_type == 'tvf':
463491
do_func = build_tvf_endpoint(func, returns_data_format)
@@ -477,6 +505,9 @@ def make_func(
477505
# Set function type
478506
info['function_type'] = function_type
479507

508+
# Set timeout
509+
info['timeout'] = max(timeout, 1)
510+
480511
# Setup argument types for rowdat_1 parser
481512
colspec = []
482513
for x in sig['args']:
@@ -498,6 +529,37 @@ def make_func(
498529
return do_func, info
499530

500531

532+
async def cancel_on_timeout(timeout: int) -> None:
533+
"""Cancel request if it takes too long."""
534+
await asyncio.sleep(timeout)
535+
raise asyncio.CancelledError(
536+
'Function call was cancelled due to timeout',
537+
)
538+
539+
540+
async def cancel_on_disconnect(
541+
receive: Callable[..., Awaitable[Any]],
542+
) -> None:
543+
"""Cancel request if client disconnects."""
544+
while True:
545+
message = await receive()
546+
if message['type'] == 'http.disconnect':
547+
raise asyncio.CancelledError(
548+
'Function call was cancelled by client',
549+
)
550+
551+
552+
def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
553+
"""Cancel all tasks."""
554+
for task in tasks:
555+
if task.done():
556+
continue
557+
try:
558+
task.cancel()
559+
except Exception:
560+
pass
561+
562+
501563
class Application(object):
502564
"""
503565
Create an external function application.
@@ -851,6 +913,8 @@ async def __call__(
851913
more_body = True
852914
while more_body:
853915
request = await receive()
916+
if request['type'] == 'http.disconnect':
917+
raise RuntimeError('client disconnected')
854918
data.append(request['body'])
855919
more_body = request.get('more_body', False)
856920

@@ -859,21 +923,87 @@ async def __call__(
859923
output_handler = self.handlers[(accepts, data_version, returns_data_format)]
860924

861925
try:
862-
out = await func(
863-
*input_handler['load']( # type: ignore
864-
func_info['colspec'], b''.join(data),
926+
result = []
927+
928+
cancel_event = threading.Event()
929+
930+
func_task = asyncio.create_task(
931+
to_thread(
932+
lambda: asyncio.run(
933+
func(
934+
cancel_event,
935+
*input_handler['load']( # type: ignore
936+
func_info['colspec'], b''.join(data),
937+
),
938+
),
939+
),
865940
),
866941
)
942+
disconnect_task = asyncio.create_task(
943+
cancel_on_disconnect(receive),
944+
)
945+
timeout_task = asyncio.create_task(
946+
cancel_on_timeout(func_info['timeout']),
947+
)
948+
949+
all_tasks = [func_task, disconnect_task, timeout_task]
950+
951+
done, pending = await asyncio.wait(
952+
all_tasks, return_when=asyncio.FIRST_COMPLETED,
953+
)
954+
955+
cancel_all_tasks(pending)
956+
957+
for task in done:
958+
if task is disconnect_task:
959+
cancel_event.set()
960+
raise asyncio.CancelledError(
961+
'Function call was cancelled by client disconnect',
962+
)
963+
964+
elif task is timeout_task:
965+
cancel_event.set()
966+
raise asyncio.TimeoutError(
967+
'Function call was cancelled due to timeout',
968+
)
969+
970+
elif task is func_task:
971+
result.extend(task.result())
972+
867973
body = output_handler['dump'](
868-
[x[1] for x in func_info['returns']], *out, # type: ignore
974+
[x[1] for x in func_info['returns']], *result, # type: ignore
869975
)
976+
870977
await send(output_handler['response'])
871978

979+
except asyncio.TimeoutError:
980+
logging.exception(
981+
'Timeout in function call: ' + func_name.decode('utf-8'),
982+
)
983+
body = (
984+
'[TimeoutError] Function call timed out after ' +
985+
str(func_info['timeout']) +
986+
' seconds'
987+
).encode('utf-8')
988+
await send(self.error_response_dict)
989+
990+
except asyncio.CancelledError:
991+
logging.exception(
992+
'Function call cancelled: ' + func_name.decode('utf-8'),
993+
)
994+
body = b'[CancelledError] Function call was cancelled'
995+
await send(self.error_response_dict)
996+
872997
except Exception as e:
873-
logging.exception('Error in function call')
998+
logging.exception(
999+
'Error in function call: ' + func_name.decode('utf-8'),
1000+
)
8741001
body = f'[{type(e).__name__}] {str(e).strip()}'.encode('utf-8')
8751002
await send(self.error_response_dict)
8761003

1004+
finally:
1005+
cancel_all_tasks(all_tasks)
1006+
8771007
# Handle api reflection
8781008
elif method == 'GET' and path == self.show_create_function_path:
8791009
host = headers.get(b'host', b'localhost:80')

0 commit comments

Comments
 (0)