24
24
"""
25
25
import argparse
26
26
import asyncio
27
+ import contextvars
27
28
import dataclasses
29
+ import functools
28
30
import importlib .util
29
31
import inspect
30
32
import io
37
39
import sys
38
40
import tempfile
39
41
import textwrap
42
+ import threading
40
43
import typing
41
44
import urllib
42
45
import zipfile
95
98
func_map = itertools .starmap
96
99
97
100
101
async def to_thread(
    func: Any, /, *args: Any, **kwargs: Any,
) -> Any:
    """
    Run a blocking callable in the running loop's default executor.

    The caller's ``contextvars`` context is copied and the callable is
    invoked inside that copy, so context variables set by the caller are
    visible to ``func`` while it runs in the worker thread.

    Parameters
    ----------
    func : Callable
        The callable to execute
    *args : Any
        Positional arguments passed through to ``func``
    **kwargs : Any
        Keyword arguments passed through to ``func``

    Returns
    -------
    Any
        Whatever ``func`` returns

    """
    loop = asyncio.get_running_loop()
    # Snapshot the context here, in the calling task, not in the worker thread
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    # ``None`` selects the event loop's default executor
    return await loop.run_in_executor(None, func_call)
108
+
109
+
98
110
# Use negative values to indicate unsigned ints / binary data / usec time precision
99
111
rowdat_1_type_map = {
100
112
'bool' : ft .LONGLONG ,
@@ -274,11 +286,19 @@ def build_udf_endpoint(
274
286
if returns_data_format in ['scalar' , 'list' ]:
275
287
276
288
async def do_func (
289
+ cancel_event : threading .Event ,
277
290
row_ids : Sequence [int ],
278
291
rows : Sequence [Sequence [Any ]],
279
292
) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
280
293
'''Call function on given rows of data.'''
281
- return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
294
+ out = []
295
+ for row in rows :
296
+ if cancel_event .is_set ():
297
+ raise asyncio .CancelledError (
298
+ 'Function call was cancelled' ,
299
+ )
300
+ out .append (func (* row ))
301
+ return row_ids , list (zip (out ))
282
302
283
303
return do_func
284
304
@@ -309,6 +329,7 @@ def build_vector_udf_endpoint(
309
329
array_cls = get_array_class (returns_data_format )
310
330
311
331
async def do_func (
332
+ cancel_event : threading .Event ,
312
333
row_ids : Sequence [int ],
313
334
cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
314
335
) -> Tuple [
@@ -361,6 +382,7 @@ def build_tvf_endpoint(
361
382
if returns_data_format in ['scalar' , 'list' ]:
362
383
363
384
async def do_func (
385
+ cancel_event : threading .Event ,
364
386
row_ids : Sequence [int ],
365
387
rows : Sequence [Sequence [Any ]],
366
388
) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
@@ -369,6 +391,10 @@ async def do_func(
369
391
out = []
370
392
# Call function on each row of data
371
393
for i , res in zip (row_ids , func_map (func , rows )):
394
+ if cancel_event .is_set ():
395
+ raise asyncio .CancelledError (
396
+ 'Function call was cancelled' ,
397
+ )
372
398
out .extend (as_list_of_tuples (res ))
373
399
out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
374
400
return out_ids , out
@@ -402,6 +428,7 @@ def build_vector_tvf_endpoint(
402
428
array_cls = get_array_class (returns_data_format )
403
429
404
430
async def do_func (
431
+ cancel_event : threading .Event ,
405
432
row_ids : Sequence [int ],
406
433
cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
407
434
) -> Tuple [
@@ -458,6 +485,7 @@ def make_func(
458
485
function_type = sig .get ('function_type' , 'udf' )
459
486
args_data_format = sig .get ('args_data_format' , 'scalar' )
460
487
returns_data_format = sig .get ('returns_data_format' , 'scalar' )
488
+ timeout = sig .get ('timeout' , get_option ('external_function.timeout' ))
461
489
462
490
if function_type == 'tvf' :
463
491
do_func = build_tvf_endpoint (func , returns_data_format )
@@ -477,6 +505,9 @@ def make_func(
477
505
# Set function type
478
506
info ['function_type' ] = function_type
479
507
508
+ # Set timeout
509
+ info ['timeout' ] = max (timeout , 1 )
510
+
480
511
# Setup argument types for rowdat_1 parser
481
512
colspec = []
482
513
for x in sig ['args' ]:
@@ -498,6 +529,37 @@ def make_func(
498
529
return do_func , info
499
530
500
531
532
async def cancel_on_timeout(timeout: int) -> None:
    """
    Sleep for the given time limit, then signal cancellation.

    Parameters
    ----------
    timeout : int
        Number of seconds to wait before raising

    Raises
    ------
    asyncio.CancelledError
        Always, once ``timeout`` seconds have elapsed

    """
    await asyncio.sleep(timeout)
    # Reaching this point means the time limit expired
    expired = asyncio.CancelledError(
        'Function call was cancelled due to timeout',
    )
    raise expired
538
+
539
+
540
async def cancel_on_disconnect(
    receive: Callable[..., Awaitable[Any]],
) -> None:
    """
    Read incoming messages until the client disconnects, then cancel.

    Parameters
    ----------
    receive : Callable[..., Awaitable[Any]]
        Awaitable callable returning the next incoming message; each
        message is indexed by its ``'type'`` key

    Raises
    ------
    asyncio.CancelledError
        When a message of type ``'http.disconnect'`` is received

    """
    message = await receive()
    # Messages of any other type are read and discarded
    while message['type'] != 'http.disconnect':
        message = await receive()
    raise asyncio.CancelledError(
        'Function call was cancelled by client',
    )
550
+
551
+
552
def cancel_all_tasks(tasks: 'Iterable[asyncio.Task[Any]]') -> None:
    """
    Request cancellation of every task that has not finished yet.

    Parameters
    ----------
    tasks : Iterable[asyncio.Task[Any]]
        Tasks to cancel; tasks that are already done are left untouched

    """
    # NOTE: The annotation is quoted because ``asyncio.Task`` is only
    # subscriptable at runtime on Python 3.9+; unquoted, evaluating the
    # signature raises ``TypeError`` on older interpreters.
    for task in tasks:
        if task.done():
            continue
        try:
            task.cancel()
        except Exception:
            # Best-effort cleanup: a failure on one task must not prevent
            # the remaining tasks from being cancelled
            logging.debug('Unable to cancel task: %r', task, exc_info=True)
561
+
562
+
501
563
class Application (object ):
502
564
"""
503
565
Create an external function application.
@@ -851,6 +913,8 @@ async def __call__(
851
913
more_body = True
852
914
while more_body :
853
915
request = await receive ()
916
+ if request ['type' ] == 'http.disconnect' :
917
+ raise RuntimeError ('client disconnected' )
854
918
data .append (request ['body' ])
855
919
more_body = request .get ('more_body' , False )
856
920
@@ -859,21 +923,87 @@ async def __call__(
859
923
output_handler = self .handlers [(accepts , data_version , returns_data_format )]
860
924
861
925
try :
862
- out = await func (
863
- * input_handler ['load' ]( # type: ignore
864
- func_info ['colspec' ], b'' .join (data ),
926
+ result = []
927
+
928
+ cancel_event = threading .Event ()
929
+
930
+ func_task = asyncio .create_task (
931
+ to_thread (
932
+ lambda : asyncio .run (
933
+ func (
934
+ cancel_event ,
935
+ * input_handler ['load' ]( # type: ignore
936
+ func_info ['colspec' ], b'' .join (data ),
937
+ ),
938
+ ),
939
+ ),
865
940
),
866
941
)
942
+ disconnect_task = asyncio .create_task (
943
+ cancel_on_disconnect (receive ),
944
+ )
945
+ timeout_task = asyncio .create_task (
946
+ cancel_on_timeout (func_info ['timeout' ]),
947
+ )
948
+
949
+ all_tasks = [func_task , disconnect_task , timeout_task ]
950
+
951
+ done , pending = await asyncio .wait (
952
+ all_tasks , return_when = asyncio .FIRST_COMPLETED ,
953
+ )
954
+
955
+ cancel_all_tasks (pending )
956
+
957
+ for task in done :
958
+ if task is disconnect_task :
959
+ cancel_event .set ()
960
+ raise asyncio .CancelledError (
961
+ 'Function call was cancelled by client disconnect' ,
962
+ )
963
+
964
+ elif task is timeout_task :
965
+ cancel_event .set ()
966
+ raise asyncio .TimeoutError (
967
+ 'Function call was cancelled due to timeout' ,
968
+ )
969
+
970
+ elif task is func_task :
971
+ result .extend (task .result ())
972
+
867
973
body = output_handler ['dump' ](
868
- [x [1 ] for x in func_info ['returns' ]], * out , # type: ignore
974
+ [x [1 ] for x in func_info ['returns' ]], * result , # type: ignore
869
975
)
976
+
870
977
await send (output_handler ['response' ])
871
978
979
+ except asyncio .TimeoutError :
980
+ logging .exception (
981
+ 'Timeout in function call: ' + func_name .decode ('utf-8' ),
982
+ )
983
+ body = (
984
+ '[TimeoutError] Function call timed out after ' +
985
+ str (func_info ['timeout' ]) +
986
+ ' seconds'
987
+ ).encode ('utf-8' )
988
+ await send (self .error_response_dict )
989
+
990
+ except asyncio .CancelledError :
991
+ logging .exception (
992
+ 'Function call cancelled: ' + func_name .decode ('utf-8' ),
993
+ )
994
+ body = b'[CancelledError] Function call was cancelled'
995
+ await send (self .error_response_dict )
996
+
872
997
except Exception as e :
873
- logging .exception ('Error in function call' )
998
+ logging .exception (
999
+ 'Error in function call: ' + func_name .decode ('utf-8' ),
1000
+ )
874
1001
body = f'[{ type (e ).__name__ } ] { str (e ).strip ()} ' .encode ('utf-8' )
875
1002
await send (self .error_response_dict )
876
1003
1004
+ finally :
1005
+ cancel_all_tasks (all_tasks )
1006
+
877
1007
# Handle api reflection
878
1008
elif method == 'GET' and path == self .show_create_function_path :
879
1009
host = headers .get (b'host' , b'localhost:80' )
0 commit comments