Skip to content

Commit 123cc90

Browse files
committed
Add testing
1 parent effcf41 commit 123cc90

File tree

4 files changed

+198
-33
lines changed

4 files changed

+198
-33
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,27 @@ def build_tuple(x: Any) -> Any:
263263
return tuple(x) if isinstance(x, Masked) else (x, None)
264264

265265

266+
def cancel_on_event(
267+
cancel_event: threading.Event,
268+
) -> None:
269+
"""
270+
Cancel the function call if the cancel event is set.
271+
272+
Parameters
273+
----------
274+
cancel_event : threading.Event
275+
The event to check for cancellation
276+
277+
Raises
278+
------
279+
asyncio.CancelledError
280+
If the cancel event is set
281+
282+
"""
283+
if cancel_event.is_set():
284+
raise asyncio.CancelledError('Function call was cancelled')
285+
286+
266287
def build_udf_endpoint(
267288
func: Callable[..., Any],
268289
returns_data_format: str,
@@ -295,10 +316,7 @@ async def do_func(
295316
'''Call function on given rows of data.'''
296317
out = []
297318
for row in rows:
298-
if cancel_event.is_set():
299-
raise asyncio.CancelledError(
300-
'Function call was cancelled',
301-
)
319+
cancel_on_event(cancel_event)
302320
if is_async:
303321
out.append(await func(*row))
304322
else:
@@ -357,6 +375,8 @@ async def do_func(
357375
else:
358376
out = func()
359377

378+
cancel_on_event(cancel_event)
379+
360380
# Single masked value
361381
if isinstance(out, Masked):
362382
return row_ids, [tuple(out)]
@@ -405,10 +425,7 @@ async def do_func(
405425
out = []
406426
# Call function on each row of data
407427
for i, row in zip(row_ids, rows):
408-
if cancel_event.is_set():
409-
raise asyncio.CancelledError(
410-
'Function call was cancelled',
411-
)
428+
cancel_on_event(cancel_event)
412429
if is_async:
413430
res = await func(*row)
414431
else:
@@ -476,6 +493,8 @@ async def do_func(
476493
else:
477494
res = get_dataframe_columns(func())
478495

496+
cancel_on_event(cancel_event)
497+
479498
# Generate row IDs
480499
if isinstance(res[0], Masked):
481500
row_ids = array_cls([row_ids[0]] * len(res[0][0]))
@@ -513,7 +532,10 @@ def make_func(
513532
function_type = sig.get('function_type', 'udf')
514533
args_data_format = sig.get('args_data_format', 'scalar')
515534
returns_data_format = sig.get('returns_data_format', 'scalar')
516-
timeout = sig.get('timeout', get_option('external_function.timeout'))
535+
timeout = (
536+
func._singlestoredb_attrs.get('timeout') or # type: ignore
537+
get_option('external_function.timeout')
538+
)
517539

518540
if function_type == 'tvf':
519541
do_func = build_tvf_endpoint(func, returns_data_format)
@@ -954,40 +976,31 @@ async def __call__(
954976
output_handler = self.handlers[(accepts, data_version, returns_data_format)]
955977

956978
try:
979+
all_tasks = []
957980
result = []
958981

959982
cancel_event = threading.Event()
960983

961-
if func_info['is_async']:
962-
func_task = asyncio.create_task(
963-
func(
964-
cancel_event,
965-
*input_handler['load']( # type: ignore
966-
func_info['colspec'], b''.join(data),
967-
),
968-
),
969-
)
970-
else:
971-
func_task = asyncio.create_task(
972-
to_thread(
973-
lambda: asyncio.run(
974-
func(
975-
cancel_event,
976-
*input_handler['load']( # type: ignore
977-
func_info['colspec'], b''.join(data),
978-
),
979-
),
980-
),
981-
),
982-
)
984+
func_args = [
985+
cancel_event,
986+
*input_handler['load']( # type: ignore
987+
func_info['colspec'], b''.join(data),
988+
),
989+
]
990+
991+
func_task = asyncio.create_task(
992+
func(*func_args)
993+
if func_info['is_async']
994+
else to_thread(lambda: asyncio.run(func(*func_args))),
995+
)
983996
disconnect_task = asyncio.create_task(
984997
cancel_on_disconnect(receive),
985998
)
986999
timeout_task = asyncio.create_task(
9871000
cancel_on_timeout(func_info['timeout']),
9881001
)
9891002

990-
all_tasks = [func_task, disconnect_task, timeout_task]
1003+
all_tasks += [func_task, disconnect_task, timeout_task]
9911004

9921005
done, pending = await asyncio.wait(
9931006
all_tasks, return_when=asyncio.FIRST_COMPLETED,
@@ -1011,7 +1024,6 @@ async def __call__(
10111024
elif task is func_task:
10121025
result.extend(task.result())
10131026

1014-
print(result)
10151027
body = output_handler['dump'](
10161028
[x[1] for x in func_info['returns']], *result, # type: ignore
10171029
)

singlestoredb/tests/ext_funcs/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22
# mypy: disable-error-code="type-arg"
3+
import asyncio
4+
import time
35
import typing
46
from typing import List
57
from typing import NamedTuple
@@ -36,6 +38,25 @@ def double_mult(x: float, y: float) -> float:
3638
return x * y
3739

3840

41+
@udf(timeout=2)
42+
def timeout_double_mult(x: float, y: float) -> float:
43+
print('TIMEOUT', x, y)
44+
time.sleep(5)
45+
return x * y
46+
47+
48+
@udf
49+
async def async_double_mult(x: float, y: float) -> float:
50+
return x * y
51+
52+
53+
@udf(timeout=2)
54+
async def async_timeout_double_mult(x: float, y: float) -> float:
55+
print('ASYNC TIMEOUT', x, y)
56+
await asyncio.sleep(5)
57+
return x * y
58+
59+
3960
@udf(
4061
args=[DOUBLE(nullable=False), DOUBLE(nullable=False)],
4162
returns=DOUBLE(nullable=False),
@@ -52,6 +73,14 @@ def numpy_double_mult(
5273
return x * y
5374

5475

76+
@udf
77+
async def async_numpy_double_mult(
78+
x: npt.NDArray[np.float64],
79+
y: npt.NDArray[np.float64],
80+
) -> npt.NDArray[np.float64]:
81+
return x * y
82+
83+
5584
@udf(
5685
args=[DOUBLE(nullable=False), DOUBLE(nullable=False)],
5786
returns=DOUBLE(nullable=False),
@@ -537,6 +566,11 @@ def table_function(n: int) -> Table[List[int]]:
537566
return Table([10] * n)
538567

539568

569+
@udf
570+
async def async_table_function(n: int) -> Table[List[int]]:
571+
return Table([10] * n)
572+
573+
540574
@udf(
541575
returns=[
542576
dt.INT(name='c_int', nullable=False),
@@ -594,6 +628,13 @@ def vec_function_df(
594628
return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3]))
595629

596630

631+
@udf(args=VecInputs, returns=DFOutputs)
632+
async def async_vec_function_df(
633+
x: npt.NDArray[np.int_], y: npt.NDArray[np.int_],
634+
) -> Table[pd.DataFrame]:
635+
return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3]))
636+
637+
597638
class MaskOutputs(typing.NamedTuple):
598639
res: Optional[np.int16]
599640

singlestoredb/tests/test.sql

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@ INSERT INTO data SET id='e', name='elephants', value=0;
1414

1515
COMMIT;
1616

17+
CREATE ROWSTORE TABLE IF NOT EXISTS longer_data (
18+
id VARCHAR(255) NOT NULL,
19+
name VARCHAR(255) NOT NULL,
20+
value BIGINT NOT NULL,
21+
PRIMARY KEY (id) USING HASH
22+
) DEFAULT CHARSET = utf8 COLLATE = utf8_unicode_ci;
23+
24+
INSERT INTO longer_data SET id='a', name='antelopes', value=2;
25+
INSERT INTO longer_data SET id='b', name='bears', value=2;
26+
INSERT INTO longer_data SET id='c', name='cats', value=5;
27+
INSERT INTO longer_data SET id='d', name='dogs', value=4;
28+
INSERT INTO longer_data SET id='e', name='elephants', value=0;
29+
INSERT INTO longer_data SET id='f', name='ferrets', value=2;
30+
INSERT INTO longer_data SET id='g', name='gorillas', value=4;
31+
INSERT INTO longer_data SET id='h', name='horses', value=6;
32+
INSERT INTO longer_data SET id='i', name='iguanas', value=2;
33+
INSERT INTO longer_data SET id='j', name='jaguars', value=0;
34+
INSERT INTO longer_data SET id='k', name='kiwis', value=0;
35+
INSERT INTO longer_data SET id='l', name='leopards', value=1;
36+
37+
COMMIT;
38+
1739
CREATE ROWSTORE TABLE IF NOT EXISTS data_with_nulls (
1840
id VARCHAR(255) NOT NULL,
1941
name VARCHAR(255),

singlestoredb/tests/test_ext_func.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,43 @@ def test_double_mult(self):
162162
'from data order by id',
163163
)
164164

165+
def test_timeout_double_mult(self):
166+
with self.assertRaises(self.conn.OperationalError) as exc:
167+
self.cur.execute(
168+
'select timeout_double_mult(value, 100) as res '
169+
'from longer_data order by id',
170+
)
171+
assert 'timeout' in str(exc.exception).lower()
172+
173+
def test_async_double_mult(self):
174+
self.cur.execute(
175+
'select async_double_mult(value, 100) as res from data order by id',
176+
)
177+
178+
assert [tuple(x) for x in self.cur] == \
179+
[(200.0,), (200.0,), (500.0,), (400.0,), (0.0,)]
180+
181+
desc = self.cur.description
182+
assert len(desc) == 1
183+
assert desc[0].name == 'res'
184+
assert desc[0].type_code == ft.DOUBLE
185+
assert desc[0].null_ok is False
186+
187+
# NULL is not valid
188+
with self.assertRaises(self.conn.OperationalError):
189+
self.cur.execute(
190+
'select async_double_mult(value, NULL) as res '
191+
'from data order by id',
192+
)
193+
194+
def test_async_timeout_double_mult(self):
195+
with self.assertRaises(self.conn.OperationalError) as exc:
196+
self.cur.execute(
197+
'select async_timeout_double_mult(value, 100) as res '
198+
'from longer_data order by id',
199+
)
200+
assert 'timeout' in str(exc.exception).lower()
201+
165202
def test_pandas_double_mult(self):
166203
self.cur.execute(
167204
'select pandas_double_mult(value, 100) as res '
@@ -206,6 +243,28 @@ def test_numpy_double_mult(self):
206243
'from data order by id',
207244
)
208245

246+
def test_async_numpy_double_mult(self):
247+
self.cur.execute(
248+
'select async_numpy_double_mult(value, 100) as res '
249+
'from data order by id',
250+
)
251+
252+
assert [tuple(x) for x in self.cur] == \
253+
[(200.0,), (200.0,), (500.0,), (400.0,), (0.0,)]
254+
255+
desc = self.cur.description
256+
assert len(desc) == 1
257+
assert desc[0].name == 'res'
258+
assert desc[0].type_code == ft.DOUBLE
259+
assert desc[0].null_ok is False
260+
261+
# NULL is not valid
262+
with self.assertRaises(self.conn.OperationalError):
263+
self.cur.execute(
264+
'select async_numpy_double_mult(value, NULL) as res '
265+
'from data order by id',
266+
)
267+
209268
def test_arrow_double_mult(self):
210269
self.cur.execute(
211270
'select arrow_double_mult(value, 100) as res '
@@ -1246,6 +1305,17 @@ def test_table_function(self):
12461305
assert desc[0].type_code == ft.LONGLONG
12471306
assert desc[0].null_ok is False
12481307

1308+
def test_async_table_function(self):
1309+
self.cur.execute('select * from async_table_function(5)')
1310+
1311+
assert [x[0] for x in self.cur] == [10, 10, 10, 10, 10]
1312+
1313+
desc = self.cur.description
1314+
assert len(desc) == 1
1315+
assert desc[0].name == 'a'
1316+
assert desc[0].type_code == ft.LONGLONG
1317+
assert desc[0].null_ok is False
1318+
12491319
def test_table_function_tuple(self):
12501320
self.cur.execute('select * from table_function_tuple(3)')
12511321

@@ -1310,6 +1380,26 @@ def test_vec_function_df(self):
13101380
assert desc[1].type_code == ft.DOUBLE
13111381
assert desc[1].null_ok is False
13121382

1383+
def test_async_vec_function_df(self):
1384+
self.cur.execute('select * from async_vec_function_df(5, 10)')
1385+
1386+
out = list(self.cur)
1387+
1388+
assert out == [
1389+
(1, 1.1),
1390+
(2, 2.2),
1391+
(3, 3.3),
1392+
]
1393+
1394+
desc = self.cur.description
1395+
assert len(desc) == 2
1396+
assert desc[0].name == 'res'
1397+
assert desc[0].type_code == ft.SHORT
1398+
assert desc[0].null_ok is False
1399+
assert desc[1].name == 'res2'
1400+
assert desc[1].type_code == ft.DOUBLE
1401+
assert desc[1].null_ok is False
1402+
13131403
def test_vec_function_ints_masked(self):
13141404
self.cur.execute('select * from vec_function_ints_masked(5, 10)')
13151405

0 commit comments

Comments
 (0)