@@ -285,6 +285,8 @@ def build_udf_endpoint(
285
285
"""
286
286
if returns_data_format in ['scalar' , 'list' ]:
287
287
288
+ is_async = asyncio .iscoroutinefunction (func )
289
+
288
290
async def do_func (
289
291
cancel_event : threading .Event ,
290
292
row_ids : Sequence [int ],
@@ -297,7 +299,10 @@ async def do_func(
297
299
raise asyncio .CancelledError (
298
300
'Function call was cancelled' ,
299
301
)
300
- out .append (func (* row ))
302
+ if is_async :
303
+ out .append (await func (* row ))
304
+ else :
305
+ out .append (func (* row ))
301
306
return row_ids , list (zip (out ))
302
307
303
308
return do_func
@@ -327,6 +332,7 @@ def build_vector_udf_endpoint(
327
332
"""
328
333
masks = get_masked_params (func )
329
334
array_cls = get_array_class (returns_data_format )
335
+ is_async = asyncio .iscoroutinefunction (func )
330
336
331
337
async def do_func (
332
338
cancel_event : threading .Event ,
@@ -341,9 +347,15 @@ async def do_func(
341
347
342
348
# Call the function with `cols` as the function parameters
343
349
if cols and cols [0 ]:
344
- out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
350
+ if is_async :
351
+ out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
352
+ else :
353
+ out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
345
354
else :
346
- out = func ()
355
+ if is_async :
356
+ out = await func ()
357
+ else :
358
+ out = func ()
347
359
348
360
# Single masked value
349
361
if isinstance (out , Masked ):
@@ -381,6 +393,8 @@ def build_tvf_endpoint(
381
393
"""
382
394
if returns_data_format in ['scalar' , 'list' ]:
383
395
396
+ is_async = asyncio .iscoroutinefunction (func )
397
+
384
398
async def do_func (
385
399
cancel_event : threading .Event ,
386
400
row_ids : Sequence [int ],
@@ -390,11 +404,15 @@ async def do_func(
390
404
out_ids : List [int ] = []
391
405
out = []
392
406
# Call function on each row of data
393
- for i , res in zip (row_ids , func_map ( func , rows ) ):
407
+ for i , row in zip (row_ids , rows ):
394
408
if cancel_event .is_set ():
395
409
raise asyncio .CancelledError (
396
410
'Function call was cancelled' ,
397
411
)
412
+ if is_async :
413
+ res = await func (* row )
414
+ else :
415
+ res = func (* row )
398
416
out .extend (as_list_of_tuples (res ))
399
417
out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
400
418
return out_ids , out
@@ -440,13 +458,23 @@ async def do_func(
440
458
# each result row, so we just have to use the same
441
459
# row ID for all rows in the result.
442
460
461
+ is_async = asyncio .iscoroutinefunction (func )
462
+
443
463
# Call function on each column of data
444
464
if cols and cols [0 ]:
445
- res = get_dataframe_columns (
446
- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
447
- )
465
+ if is_async :
466
+ res = get_dataframe_columns (
467
+ await func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
468
+ )
469
+ else :
470
+ res = get_dataframe_columns (
471
+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
472
+ )
448
473
else :
449
- res = get_dataframe_columns (func ())
474
+ if is_async :
475
+ res = get_dataframe_columns (await func ())
476
+ else :
477
+ res = get_dataframe_columns (func ())
450
478
451
479
# Generate row IDs
452
480
if isinstance (res [0 ], Masked ):
@@ -508,6 +536,9 @@ def make_func(
508
536
# Set timeout
509
537
info ['timeout' ] = max (timeout , 1 )
510
538
539
+ # Set async flag
540
+ info ['is_async' ] = asyncio .iscoroutinefunction (func )
541
+
511
542
# Setup argument types for rowdat_1 parser
512
543
colspec = []
513
544
for x in sig ['args' ]:
@@ -927,18 +958,28 @@ async def __call__(
927
958
928
959
cancel_event = threading .Event ()
929
960
930
- func_task = asyncio .create_task (
931
- to_thread (
932
- lambda : asyncio .run (
933
- func (
934
- cancel_event ,
935
- * input_handler ['load' ]( # type: ignore
936
- func_info ['colspec' ], b'' .join (data ),
961
+ if func_info ['is_async' ]:
962
+ func_task = asyncio .create_task (
963
+ func (
964
+ cancel_event ,
965
+ * input_handler ['load' ]( # type: ignore
966
+ func_info ['colspec' ], b'' .join (data ),
967
+ ),
968
+ ),
969
+ )
970
+ else :
971
+ func_task = asyncio .create_task (
972
+ to_thread (
973
+ lambda : asyncio .run (
974
+ func (
975
+ cancel_event ,
976
+ * input_handler ['load' ]( # type: ignore
977
+ func_info ['colspec' ], b'' .join (data ),
978
+ ),
937
979
),
938
980
),
939
981
),
940
- ),
941
- )
982
+ )
942
983
disconnect_task = asyncio .create_task (
943
984
cancel_on_disconnect (receive ),
944
985
)
@@ -970,6 +1011,7 @@ async def __call__(
970
1011
elif task is func_task :
971
1012
result .extend (task .result ())
972
1013
1014
+ print (result )
973
1015
body = output_handler ['dump' ](
974
1016
[x [1 ] for x in func_info ['returns' ]], * result , # type: ignore
975
1017
)
0 commit comments