From 463f5d534ca7a3f223e1eac554f7faf5249f0c61 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 5 Jun 2025 09:42:18 -0500 Subject: [PATCH 01/10] Ability to cancel a UDF on disconnect or timeout --- singlestoredb/config.py | 5 + singlestoredb/functions/decorator.py | 2 + singlestoredb/functions/ext/asgi.py | 142 +++++++++++++++++++++++++-- 3 files changed, 143 insertions(+), 6 deletions(-) diff --git a/singlestoredb/config.py b/singlestoredb/config.py index 61b79f29..a5e4805d 100644 --- a/singlestoredb/config.py +++ b/singlestoredb/config.py @@ -438,6 +438,11 @@ environ=['SINGLESTOREDB_EXT_FUNC_PORT'], ) +register_option( + 'external_function.timeout', 'int', check_int, 24*60*60, + 'Specifies the timeout in seconds for processing a batch of rows.', + environ=['SINGLESTOREDB_EXT_FUNC_TIMEOUT'], +) # # Debugging options diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 2280ed40..13cc8b8a 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -100,6 +100,7 @@ def _func( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, + timeout: Optional[int] = None, ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" @@ -108,6 +109,7 @@ def _func( name=name, args=expand_types(args), returns=expand_types(returns), + timeout=timeout, ).items() if v is not None } diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index c8ff8f4f..5b421282 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,7 +24,9 @@ """ import argparse import asyncio +import contextvars import dataclasses +import functools import importlib.util import inspect import io @@ -37,6 +39,7 @@ import sys import tempfile import textwrap +import threading import typing import urllib import zipfile @@ -95,6 +98,15 @@ func_map = itertools.starmap +async def to_thread( + func: Any, /, *args: Any, **kwargs: Dict[str, Any], +) -> Any: + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -274,11 +286,19 @@ def build_udf_endpoint( if returns_data_format in ['scalar', 'list']: async def do_func( + cancel_event: threading.Event, row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' - return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] + out = [] + for row in rows: + if cancel_event.is_set(): + raise asyncio.CancelledError( + 'Function call was cancelled', + ) + out.append(func(*row)) + return row_ids, list(zip(out)) return do_func @@ -309,6 +329,7 @@ def build_vector_udf_endpoint( array_cls = get_array_class(returns_data_format) async def do_func( + cancel_event: threading.Event, row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], ) -> Tuple[ @@ -361,6 +382,7 @@ def build_tvf_endpoint( if returns_data_format in ['scalar', 'list']: async def do_func( + cancel_event: threading.Event, row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: @@ -369,6 +391,10 @@ async def do_func( out = [] # Call function on each row of data for i, res in zip(row_ids, func_map(func, 
rows)): + if cancel_event.is_set(): + raise asyncio.CancelledError( + 'Function call was cancelled', + ) out.extend(as_list_of_tuples(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out @@ -402,6 +428,7 @@ def build_vector_tvf_endpoint( array_cls = get_array_class(returns_data_format) async def do_func( + cancel_event: threading.Event, row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], ) -> Tuple[ @@ -458,6 +485,7 @@ def make_func( function_type = sig.get('function_type', 'udf') args_data_format = sig.get('args_data_format', 'scalar') returns_data_format = sig.get('returns_data_format', 'scalar') + timeout = sig.get('timeout', get_option('external_function.timeout')) if function_type == 'tvf': do_func = build_tvf_endpoint(func, returns_data_format) @@ -477,6 +505,9 @@ def make_func( # Set function type info['function_type'] = function_type + # Set timeout + info['timeout'] = max(timeout, 1) + # Setup argument types for rowdat_1 parser colspec = [] for x in sig['args']: @@ -498,6 +529,37 @@ def make_func( return do_func, info +async def cancel_on_timeout(timeout: int) -> None: + """Cancel request if it takes too long.""" + await asyncio.sleep(timeout) + raise asyncio.CancelledError( + 'Function call was cancelled due to timeout', + ) + + +async def cancel_on_disconnect( + receive: Callable[..., Awaitable[Any]], +) -> None: + """Cancel request if client disconnects.""" + while True: + message = await receive() + if message['type'] == 'http.disconnect': + raise asyncio.CancelledError( + 'Function call was cancelled by client', + ) + + +def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None: + """Cancel all tasks.""" + for task in tasks: + if task.done(): + continue + try: + task.cancel() + except Exception: + pass + + class Application(object): """ Create an external function application. 
@@ -851,6 +913,8 @@ async def __call__( more_body = True while more_body: request = await receive() + if request['type'] == 'http.disconnect': + raise RuntimeError('client disconnected') data.append(request['body']) more_body = request.get('more_body', False) @@ -859,21 +923,87 @@ async def __call__( output_handler = self.handlers[(accepts, data_version, returns_data_format)] try: - out = await func( - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), + result = [] + + cancel_event = threading.Event() + + func_task = asyncio.create_task( + to_thread( + lambda: asyncio.run( + func( + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ), + ), ), ) + disconnect_task = asyncio.create_task( + cancel_on_disconnect(receive), + ) + timeout_task = asyncio.create_task( + cancel_on_timeout(func_info['timeout']), + ) + + all_tasks = [func_task, disconnect_task, timeout_task] + + done, pending = await asyncio.wait( + all_tasks, return_when=asyncio.FIRST_COMPLETED, + ) + + cancel_all_tasks(pending) + + for task in done: + if task is disconnect_task: + cancel_event.set() + raise asyncio.CancelledError( + 'Function call was cancelled by client disconnect', + ) + + elif task is timeout_task: + cancel_event.set() + raise asyncio.TimeoutError( + 'Function call was cancelled due to timeout', + ) + + elif task is func_task: + result.extend(task.result()) + body = output_handler['dump']( - [x[1] for x in func_info['returns']], *out, # type: ignore + [x[1] for x in func_info['returns']], *result, # type: ignore ) + await send(output_handler['response']) + except asyncio.TimeoutError: + logging.exception( + 'Timeout in function call: ' + func_name.decode('utf-8'), + ) + body = ( + '[TimeoutError] Function call timed out after ' + + str(func_info['timeout']) + + ' seconds' + ).encode('utf-8') + await send(self.error_response_dict) + + except asyncio.CancelledError: + logging.exception( + 'Function call cancelled: ' + func_name.decode('utf-8'), + ) + body = b'[CancelledError] Function call was cancelled' + await send(self.error_response_dict) + except Exception as e: - logging.exception('Error in function call') + logging.exception( + 'Error in function call: ' + func_name.decode('utf-8'), + ) body = f'[{type(e).__name__}] {str(e).strip()}'.encode('utf-8') await send(self.error_response_dict) + finally: + cancel_all_tasks(all_tasks) + # Handle api reflection elif method == 'GET' and path == self.show_create_function_path: host = headers.get(b'host', b'localhost:80') From 1bfec18640a36b2e8438f9dc887647d70700390f Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 5 Jun 2025 09:47:40 -0500 Subject: [PATCH 02/10] Add missing decorator arg --- singlestoredb/functions/decorator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 13cc8b8a..4eac7470 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -142,6 +142,7 @@ def udf( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, + timeout: Optional[int] = None, ) -> Callable[..., Any]: """ Define a user-defined function (UDF). @@ -169,6 +170,9 @@ def udf( Specifies the return data type of the function. This parameter works the same way as `args`. If the function is a table-valued function, the return type should be a `Table` object. + timeout : int, optional + The timeout in seconds for the UDF execution. 
If not specified,
+        the global default timeout is used.

     Returns
     -------
@@ -180,4 +184,5 @@
         name=name,
         args=args,
         returns=returns,
+        timeout=timeout,
     )

From effcf41111e12c49bc54998261ceee28adccfad7 Mon Sep 17 00:00:00 2001
From: Kevin Smith
Date: Thu, 5 Jun 2025 13:45:51 -0500
Subject: [PATCH 03/10] Add async support

---
 singlestoredb/functions/decorator.py | 38 +++++++++-----
 singlestoredb/functions/ext/asgi.py  | 76 +++++++++++++++++++++-------
 2 files changed, 84 insertions(+), 30 deletions(-)

diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py
index 4eac7470..68721136 100644
--- a/singlestoredb/functions/decorator.py
+++ b/singlestoredb/functions/decorator.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import inspect
 from typing import Any
@@ -19,6 +20,7 @@
 ]

 ReturnType = ParameterType
+UDFType = Callable[..., Any]


 def is_valid_type(obj: Any) -> bool:
@@ -101,7 +103,7 @@ def _func(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
-) -> Callable[..., Any]:
+) -> UDFType:
     """Generic wrapper for UDF and TVF decorators."""

     _singlestoredb_attrs = {  # type: ignore
@@ -117,23 +119,33 @@
     # called later, so the wrapper must be created with the func passed
     # in at that time.
     if func is None:
-        def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
+        def decorate(func: UDFType) -> UDFType:

-            def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
-                return func(*args, **kwargs)  # type: ignore
+            if asyncio.iscoroutinefunction(func):
+                async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
+                    return await func(*args, **kwargs)  # type: ignore
+                async_wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
+                return functools.wraps(func)(async_wrapper)

-            wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
-
-            return functools.wraps(func)(wrapper)
+            else:
+                def wrapper(*args: Any, **kwargs: Any) -> UDFType:
+                    return func(*args, **kwargs)  # type: ignore
+                wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
+                return functools.wraps(func)(wrapper)

         return decorate

-    def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
-        return func(*args, **kwargs)  # type: ignore
-
-    wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
+    if asyncio.iscoroutinefunction(func):
+        async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
+            return await func(*args, **kwargs)  # type: ignore
+        async_wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
+        return functools.wraps(func)(async_wrapper)

-    return functools.wraps(func)(wrapper)
+    else:
+        def wrapper(*args: Any, **kwargs: Any) -> UDFType:
+            return func(*args, **kwargs)  # type: ignore
+        wrapper._singlestoredb_attrs = _singlestoredb_attrs  # type: ignore
+        return functools.wraps(func)(wrapper)


 def udf(
@@ -143,7 +155,7 @@ def udf(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
-) -> Callable[..., Any]:
+) -> UDFType:
     """
     Define a user-defined function (UDF). 
diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 5b421282..986f3bd9 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -285,6 +285,8 @@ def build_udf_endpoint( """ if returns_data_format in ['scalar', 'list']: + is_async = asyncio.iscoroutinefunction(func) + async def do_func( cancel_event: threading.Event, row_ids: Sequence[int], @@ -297,7 +299,10 @@ async def do_func( raise asyncio.CancelledError( 'Function call was cancelled', ) - out.append(func(*row)) + if is_async: + out.append(await func(*row)) + else: + out.append(func(*row)) return row_ids, list(zip(out)) return do_func @@ -327,6 +332,7 @@ def build_vector_udf_endpoint( """ masks = get_masked_params(func) array_cls = get_array_class(returns_data_format) + is_async = asyncio.iscoroutinefunction(func) async def do_func( cancel_event: threading.Event, @@ -341,9 +347,15 @@ async def do_func( # Call the function with `cols` as the function parameters if cols and cols[0]: - out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) + if is_async: + out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) + else: + out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) else: - out = func() + if is_async: + out = await func() + else: + out = func() # Single masked value if isinstance(out, Masked): @@ -381,6 +393,8 @@ def build_tvf_endpoint( """ if returns_data_format in ['scalar', 'list']: + is_async = asyncio.iscoroutinefunction(func) + async def do_func( cancel_event: threading.Event, row_ids: Sequence[int], @@ -390,11 +404,15 @@ async def do_func( out_ids: List[int] = [] out = [] # Call function on each row of data - for i, res in zip(row_ids, func_map(func, rows)): + for i, row in zip(row_ids, rows): if cancel_event.is_set(): raise asyncio.CancelledError( 'Function call was cancelled', ) + if is_async: + res = await func(*row) + else: + res = func(*row) out.extend(as_list_of_tuples(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out @@ -440,13 +458,23 @@ async def do_func( # each result row, so we just have to use the same # row ID for all rows in the result. 
+ is_async = asyncio.iscoroutinefunction(func) + # Call function on each column of data if cols and cols[0]: - res = get_dataframe_columns( - func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) + if is_async: + res = get_dataframe_columns( + await func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) + else: + res = get_dataframe_columns( + func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) else: - res = get_dataframe_columns(func()) + if is_async: + res = get_dataframe_columns(await func()) + else: + res = get_dataframe_columns(func()) # Generate row IDs if isinstance(res[0], Masked): @@ -508,6 +536,9 @@ def make_func( # Set timeout info['timeout'] = max(timeout, 1) + # Set async flag + info['is_async'] = asyncio.iscoroutinefunction(func) + # Setup argument types for rowdat_1 parser colspec = [] for x in sig['args']: @@ -927,18 +958,28 @@ async def __call__( cancel_event = threading.Event() - func_task = asyncio.create_task( - to_thread( - lambda: asyncio.run( - func( - cancel_event, - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), + if func_info['is_async']: + func_task = asyncio.create_task( + func( + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ), + ) + else: + func_task = asyncio.create_task( + to_thread( + lambda: asyncio.run( + func( + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), ), ), ), - ), - ) + ) disconnect_task = asyncio.create_task( cancel_on_disconnect(receive), ) @@ -970,6 +1011,7 @@ async def __call__( elif task is func_task: result.extend(task.result()) + print(result) body = output_handler['dump']( [x[1] for x in func_info['returns']], *result, # type: ignore ) From 123cc906835277d0987890433afeeacb780f822e Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 6 Jun 2025 10:23:09 -0500 Subject: [PATCH 04/10] Add testing --- singlestoredb/functions/ext/asgi.py | 78 +++++++++++--------- singlestoredb/tests/ext_funcs/__init__.py | 41 +++++++++++ singlestoredb/tests/test.sql | 22 ++++++ singlestoredb/tests/test_ext_func.py | 90 +++++++++++++++++++++++ 4 files changed, 198 insertions(+), 33 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 986f3bd9..d3ee596d 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -263,6 +263,27 @@ def build_tuple(x: Any) -> Any: return tuple(x) if isinstance(x, Masked) else (x, None) +def cancel_on_event( + cancel_event: threading.Event, +) -> None: + """ + Cancel the function call if the cancel event is set. 
+ + Parameters + ---------- + cancel_event : threading.Event + The event to check for cancellation + + Raises + ------ + asyncio.CancelledError + If the cancel event is set + + """ + if cancel_event.is_set(): + raise asyncio.CancelledError('Function call was cancelled') + + def build_udf_endpoint( func: Callable[..., Any], returns_data_format: str, @@ -295,10 +316,7 @@ async def do_func( '''Call function on given rows of data.''' out = [] for row in rows: - if cancel_event.is_set(): - raise asyncio.CancelledError( - 'Function call was cancelled', - ) + cancel_on_event(cancel_event) if is_async: out.append(await func(*row)) else: @@ -357,6 +375,8 @@ async def do_func( else: out = func() + cancel_on_event(cancel_event) + # Single masked value if isinstance(out, Masked): return row_ids, [tuple(out)] @@ -405,10 +425,7 @@ async def do_func( out = [] # Call function on each row of data for i, row in zip(row_ids, rows): - if cancel_event.is_set(): - raise asyncio.CancelledError( - 'Function call was cancelled', - ) + cancel_on_event(cancel_event) if is_async: res = await func(*row) else: @@ -476,6 +493,8 @@ async def do_func( else: res = get_dataframe_columns(func()) + cancel_on_event(cancel_event) + # Generate row IDs if isinstance(res[0], Masked): row_ids = array_cls([row_ids[0]] * len(res[0][0])) @@ -513,7 +532,10 @@ def make_func( function_type = sig.get('function_type', 'udf') args_data_format = sig.get('args_data_format', 'scalar') returns_data_format = sig.get('returns_data_format', 'scalar') - timeout = sig.get('timeout', get_option('external_function.timeout')) + timeout = ( + func._singlestoredb_attrs.get('timeout') or # type: ignore + get_option('external_function.timeout') + ) if function_type == 'tvf': do_func = build_tvf_endpoint(func, returns_data_format) @@ -954,32 +976,23 @@ async def __call__( output_handler = self.handlers[(accepts, data_version, returns_data_format)] try: + all_tasks = [] result = [] cancel_event = threading.Event() - if func_info['is_async']: - func_task = asyncio.create_task( - func( - cancel_event, - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), - ), - ) - else: - func_task = asyncio.create_task( - to_thread( - lambda: asyncio.run( - func( - cancel_event, - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), - ), - ), - ), - ) + func_args = [ + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ] + + func_task = asyncio.create_task( + func(*func_args) + if func_info['is_async'] + else to_thread(lambda: asyncio.run(func(*func_args))), + ) disconnect_task = asyncio.create_task( cancel_on_disconnect(receive), ) @@ -987,7 +1000,7 @@ async def __call__( cancel_on_timeout(func_info['timeout']), ) - all_tasks = [func_task, disconnect_task, timeout_task] + all_tasks += [func_task, disconnect_task, timeout_task] done, pending = await asyncio.wait( all_tasks, return_when=asyncio.FIRST_COMPLETED, @@ -1011,7 +1024,6 @@ async def __call__( elif task is func_task: result.extend(task.result()) - print(result) body = output_handler['dump']( [x[1] for x in func_info['returns']], *result, # type: ignore ) diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index d481af9e..dd5816fe 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # mypy: disable-error-code="type-arg" +import asyncio +import time import typing from 
typing import List from typing import NamedTuple @@ -36,6 +38,25 @@ def double_mult(x: float, y: float) -> float: return x * y +@udf(timeout=2) +def timeout_double_mult(x: float, y: float) -> float: + print('TIMEOUT', x, y) + time.sleep(5) + return x * y + + +@udf +async def async_double_mult(x: float, y: float) -> float: + return x * y + + +@udf(timeout=2) +async def async_timeout_double_mult(x: float, y: float) -> float: + print('ASYNC TIMEOUT', x, y) + await asyncio.sleep(5) + return x * y + + @udf( args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], returns=DOUBLE(nullable=False), @@ -52,6 +73,14 @@ def numpy_double_mult( return x * y +@udf +async def async_numpy_double_mult( + x: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], +) -> npt.NDArray[np.float64]: + return x * y + + @udf( args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], returns=DOUBLE(nullable=False), @@ -537,6 +566,11 @@ def table_function(n: int) -> Table[List[int]]: return Table([10] * n) +@udf +async def async_table_function(n: int) -> Table[List[int]]: + return Table([10] * n) + + @udf( returns=[ dt.INT(name='c_int', nullable=False), @@ -594,6 +628,13 @@ def vec_function_df( return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) +@udf(args=VecInputs, returns=DFOutputs) +async def async_vec_function_df( + x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], +) -> Table[pd.DataFrame]: + return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) + + class MaskOutputs(typing.NamedTuple): res: Optional[np.int16] diff --git a/singlestoredb/tests/test.sql b/singlestoredb/tests/test.sql index fd7d26f2..ab3cf955 100644 --- a/singlestoredb/tests/test.sql +++ b/singlestoredb/tests/test.sql @@ -14,6 +14,28 @@ INSERT INTO data SET id='e', name='elephants', value=0; COMMIT; +CREATE ROWSTORE TABLE IF NOT EXISTS longer_data ( + id VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + value BIGINT NOT NULL, + PRIMARY KEY (id) USING HASH +) DEFAULT CHARSET = utf8 COLLATE = utf8_unicode_ci; + +INSERT INTO longer_data SET id='a', name='antelopes', value=2; +INSERT INTO longer_data SET id='b', name='bears', value=2; +INSERT INTO longer_data SET id='c', name='cats', value=5; +INSERT INTO longer_data SET id='d', name='dogs', value=4; +INSERT INTO longer_data SET id='e', name='elephants', value=0; +INSERT INTO longer_data SET id='f', name='ferrets', value=2; +INSERT INTO longer_data SET id='g', name='gorillas', value=4; +INSERT INTO longer_data SET id='h', name='horses', value=6; +INSERT INTO longer_data SET id='i', name='iguanas', value=2; +INSERT INTO longer_data SET id='j', name='jaguars', value=0; +INSERT INTO longer_data SET id='k', name='kiwis', value=0; +INSERT INTO longer_data SET id='l', name='leopards', value=1; + +COMMIT; + CREATE ROWSTORE TABLE IF NOT EXISTS data_with_nulls ( id VARCHAR(255) NOT NULL, name VARCHAR(255), diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py index 60e1ecf2..d3e680e5 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -162,6 +162,43 @@ def test_double_mult(self): 'from data order by id', ) + def test_timeout_double_mult(self): + with self.assertRaises(self.conn.OperationalError) as exc: + self.cur.execute( + 'select timeout_double_mult(value, 100) as res ' + 'from longer_data order by id', + ) + assert 'timeout' in str(exc.exception).lower() + + def test_async_double_mult(self): + self.cur.execute( + 'select async_double_mult(value, 100) as res from data order by id', + ) + + assert [tuple(x) for x 
in self.cur] == \ + [(200.0,), (200.0,), (500.0,), (400.0,), (0.0,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.DOUBLE + assert desc[0].null_ok is False + + # NULL is not valid + with self.assertRaises(self.conn.OperationalError): + self.cur.execute( + 'select async_double_mult(value, NULL) as res ' + 'from data order by id', + ) + + def test_async_timeout_double_mult(self): + with self.assertRaises(self.conn.OperationalError) as exc: + self.cur.execute( + 'select async_timeout_double_mult(value, 100) as res ' + 'from longer_data order by id', + ) + assert 'timeout' in str(exc.exception).lower() + def test_pandas_double_mult(self): self.cur.execute( 'select pandas_double_mult(value, 100) as res ' @@ -206,6 +243,28 @@ def test_numpy_double_mult(self): 'from data order by id', ) + def test_async_numpy_double_mult(self): + self.cur.execute( + 'select async_numpy_double_mult(value, 100) as res ' + 'from data order by id', + ) + + assert [tuple(x) for x in self.cur] == \ + [(200.0,), (200.0,), (500.0,), (400.0,), (0.0,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.DOUBLE + assert desc[0].null_ok is False + + # NULL is not valid + with self.assertRaises(self.conn.OperationalError): + self.cur.execute( + 'select async_numpy_double_mult(value, NULL) as res ' + 'from data order by id', + ) + def test_arrow_double_mult(self): self.cur.execute( 'select arrow_double_mult(value, 100) as res ' @@ -1246,6 +1305,17 @@ def test_table_function(self): assert desc[0].type_code == ft.LONGLONG assert desc[0].null_ok is False + def test_async_table_function(self): + self.cur.execute('select * from async_table_function(5)') + + assert [x[0] for x in self.cur] == [10, 10, 10, 10, 10] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'a' + assert desc[0].type_code == ft.LONGLONG + assert desc[0].null_ok is False + def test_table_function_tuple(self): self.cur.execute('select * from table_function_tuple(3)') @@ -1310,6 +1380,26 @@ def test_vec_function_df(self): assert desc[1].type_code == ft.DOUBLE assert desc[1].null_ok is False + def test_async_vec_function_df(self): + self.cur.execute('select * from async_vec_function_df(5, 10)') + + out = list(self.cur) + + assert out == [ + (1, 1.1), + (2, 2.2), + (3, 3.3), + ] + + desc = self.cur.description + assert len(desc) == 2 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is False + assert desc[1].name == 'res2' + assert desc[1].type_code == ft.DOUBLE + assert desc[1].null_ok is False + def test_vec_function_ints_masked(self): self.cur.execute('select * from vec_function_ints_masked(5, 10)') From a158cb687c001e89bfeedef4d445cb05b8041e6a Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 6 Jun 2025 10:50:35 -0500 Subject: [PATCH 05/10] Add type aliases --- .../{typing.py => typing/__init__.py} | 0 singlestoredb/functions/typing/numpy.py | 20 +++ singlestoredb/functions/typing/pandas.py | 2 + singlestoredb/functions/typing/polars.py | 2 + singlestoredb/functions/typing/pyarrow.py | 2 + singlestoredb/tests/ext_funcs/__init__.py | 124 +++++++++--------- 6 files changed, 87 insertions(+), 63 deletions(-) rename singlestoredb/functions/{typing.py => typing/__init__.py} (100%) create mode 100644 singlestoredb/functions/typing/numpy.py create mode 100644 singlestoredb/functions/typing/pandas.py create mode 100644 
singlestoredb/functions/typing/polars.py create mode 100644 singlestoredb/functions/typing/pyarrow.py diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing/__init__.py similarity index 100% rename from singlestoredb/functions/typing.py rename to singlestoredb/functions/typing/__init__.py diff --git a/singlestoredb/functions/typing/numpy.py b/singlestoredb/functions/typing/numpy.py new file mode 100644 index 00000000..fb3954d2 --- /dev/null +++ b/singlestoredb/functions/typing/numpy.py @@ -0,0 +1,20 @@ +import numpy as np +import numpy.typing as npt + +NDArray = npt.NDArray + +StringArray = StrArray = npt.NDArray[np.str_] +BytesArray = npt.NDArray[np.bytes_] +Float32Array = FloatArray = npt.NDArray[np.float32] +Float64Array = DoubleArray = npt.NDArray[np.float64] +IntArray = npt.NDArray[np.int_] +Int8Array = npt.NDArray[np.int8] +Int16Array = npt.NDArray[np.int16] +Int32Array = npt.NDArray[np.int32] +Int64Array = npt.NDArray[np.int64] +UInt8Array = npt.NDArray[np.uint8] +UInt16Array = npt.NDArray[np.uint16] +UInt32Array = npt.NDArray[np.uint32] +UInt64Array = npt.NDArray[np.uint64] +DateTimeArray = npt.NDArray[np.datetime64] +TimeDeltaArray = npt.NDArray[np.timedelta64] diff --git a/singlestoredb/functions/typing/pandas.py b/singlestoredb/functions/typing/pandas.py new file mode 100644 index 00000000..23a662c5 --- /dev/null +++ b/singlestoredb/functions/typing/pandas.py @@ -0,0 +1,2 @@ +from pandas import DataFrame # noqa: F401 +from pandas import Series # noqa: F401 diff --git a/singlestoredb/functions/typing/polars.py b/singlestoredb/functions/typing/polars.py new file mode 100644 index 00000000..d7556a1e --- /dev/null +++ b/singlestoredb/functions/typing/polars.py @@ -0,0 +1,2 @@ +from polars import DataFrame # noqa: F401 +from polars import Series # noqa: F401 diff --git a/singlestoredb/functions/typing/pyarrow.py b/singlestoredb/functions/typing/pyarrow.py new file mode 100644 index 00000000..7c7fce94 --- /dev/null +++ b/singlestoredb/functions/typing/pyarrow.py @@ -0,0 +1,2 @@ +from pyarrow import Array # noqa: F401 +from pyarrow import Table # noqa: F401 diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index dd5816fe..f5ea9e41 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -9,10 +9,6 @@ from typing import Tuple import numpy as np -import numpy.typing as npt -import pandas as pd -import polars as pl -import pyarrow as pa import singlestoredb.functions.dtypes as dt from singlestoredb.functions import Masked @@ -26,6 +22,10 @@ from singlestoredb.functions.dtypes import SMALLINT from singlestoredb.functions.dtypes import TEXT from singlestoredb.functions.dtypes import TINYINT +from singlestoredb.functions.typing import numpy as npt +from singlestoredb.functions.typing import pandas as pdt +from singlestoredb.functions.typing import polars as plt +from singlestoredb.functions.typing import pyarrow as pat @udf @@ -40,7 +40,6 @@ def double_mult(x: float, y: float) -> float: @udf(timeout=2) def timeout_double_mult(x: float, y: float) -> float: - print('TIMEOUT', x, y) time.sleep(5) return x * y @@ -52,7 +51,6 @@ async def async_double_mult(x: float, y: float) -> float: @udf(timeout=2) async def async_timeout_double_mult(x: float, y: float) -> float: - print('ASYNC TIMEOUT', x, y) await asyncio.sleep(5) return x * y @@ -61,23 +59,23 @@ async def async_timeout_double_mult(x: float, y: float) -> float: args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], 
returns=DOUBLE(nullable=False), ) -def pandas_double_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_double_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @udf def numpy_double_mult( - x: npt.NDArray[np.float64], - y: npt.NDArray[np.float64], -) -> npt.NDArray[np.float64]: + x: npt.Float64Array, + y: npt.Float64Array, +) -> npt.Float64Array: return x * y @udf async def async_numpy_double_mult( - x: npt.NDArray[np.float64], - y: npt.NDArray[np.float64], -) -> npt.NDArray[np.float64]: + x: npt.Float64Array, + y: npt.Float64Array, +) -> npt.Float64Array: return x * y @@ -85,7 +83,7 @@ async def async_numpy_double_mult( args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], returns=DOUBLE(nullable=False), ) -def arrow_double_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_double_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -94,7 +92,7 @@ def arrow_double_mult(x: pa.Array, y: pa.Array) -> pa.Array: args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], returns=DOUBLE(nullable=False), ) -def polars_double_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_double_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -135,12 +133,12 @@ def tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @tinyint_udf -def pandas_tinyint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_tinyint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @tinyint_udf -def polars_tinyint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_tinyint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -150,7 +148,7 @@ def numpy_tinyint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @tinyint_udf -def arrow_tinyint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -173,12 +171,12 @@ def smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @smallint_udf -def pandas_smallint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_smallint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @smallint_udf -def polars_smallint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_smallint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -188,7 +186,7 @@ def numpy_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @smallint_udf -def arrow_smallint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -212,12 +210,12 @@ def mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @mediumint_udf -def pandas_mediumint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_mediumint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @mediumint_udf -def polars_mediumint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_mediumint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -227,7 +225,7 @@ def numpy_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @mediumint_udf -def arrow_mediumint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -251,12 +249,12 @@ def bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @bigint_udf -def pandas_bigint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def 
pandas_bigint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @bigint_udf -def polars_bigint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_bigint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -266,7 +264,7 @@ def numpy_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @bigint_udf -def arrow_bigint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_bigint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -290,12 +288,12 @@ def nullable_tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_tinyint_udf -def pandas_nullable_tinyint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_nullable_tinyint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @nullable_tinyint_udf -def polars_nullable_tinyint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_nullable_tinyint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -305,7 +303,7 @@ def numpy_nullable_tinyint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @nullable_tinyint_udf -def arrow_nullable_tinyint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_nullable_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -328,12 +326,12 @@ def nullable_smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_smallint_udf -def pandas_nullable_smallint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_nullable_smallint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @nullable_smallint_udf -def polars_nullable_smallint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_nullable_smallint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -343,7 +341,7 @@ def numpy_nullable_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @nullable_smallint_udf -def arrow_nullable_smallint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_nullable_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -367,12 +365,12 @@ def nullable_mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int] @nullable_mediumint_udf -def pandas_nullable_mediumint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_nullable_mediumint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @nullable_mediumint_udf -def polars_nullable_mediumint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_nullable_mediumint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -382,7 +380,7 @@ def numpy_nullable_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @nullable_mediumint_udf -def arrow_nullable_mediumint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_nullable_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -406,12 +404,12 @@ def nullable_bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_bigint_udf -def pandas_nullable_bigint_mult(x: pd.Series, y: pd.Series) -> pd.Series: +def pandas_nullable_bigint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: return x * y @nullable_bigint_udf -def polars_nullable_bigint_mult(x: pl.Series, y: pl.Series) -> pl.Series: +def polars_nullable_bigint_mult(x: plt.Series, y: plt.Series) -> plt.Series: return x * y @@ -421,7 +419,7 @@ def numpy_nullable_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: @nullable_bigint_udf -def 
arrow_nullable_bigint_mult(x: pa.Array, y: pa.Array) -> pa.Array: +def arrow_nullable_bigint_mult(x: pat.Array, y: pat.Array) -> pat.Array: import pyarrow.compute as pc return pc.multiply(x, y) @@ -439,7 +437,7 @@ def string_mult(x: str, times: int) -> str: @udf(args=[TEXT(nullable=False), BIGINT(nullable=False)], returns=TEXT(nullable=False)) -def pandas_string_mult(x: pd.Series, times: pd.Series) -> pd.Series: +def pandas_string_mult(x: pdt.Series, times: pdt.Series) -> pdt.Series: return x * times @@ -476,8 +474,8 @@ def nullable_string_mult(x: Optional[str], times: Optional[int]) -> Optional[str returns=TINYINT(nullable=True), ) def pandas_nullable_tinyint_mult_with_masks( - x: Masked[pd.Series], y: Masked[pd.Series], -) -> Masked[pd.Series]: + x: Masked[pdt.Series], y: Masked[pdt.Series], +) -> Masked[pdt.Series]: x_data, x_nulls = x y_data, y_nulls = y return Masked(x_data * y_data, x_nulls | y_nulls) @@ -497,8 +495,8 @@ def numpy_nullable_tinyint_mult_with_masks( returns=TINYINT(nullable=True), ) def polars_nullable_tinyint_mult_with_masks( - x: Masked[pl.Series], y: Masked[pl.Series], -) -> Masked[pl.Series]: + x: Masked[plt.Series], y: Masked[plt.Series], +) -> Masked[plt.Series]: x_data, x_nulls = x y_data, y_nulls = y return Masked(x_data * y_data, x_nulls | y_nulls) @@ -509,8 +507,8 @@ def polars_nullable_tinyint_mult_with_masks( returns=TINYINT(nullable=True), ) def arrow_nullable_tinyint_mult_with_masks( - x: Masked[pa.Array], y: Masked[pa.Array], -) -> Masked[pa.Array]: + x: Masked[pat.Array], y: Masked[pat.Array], +) -> Masked[pat.Array]: import pyarrow.compute as pc x_data, x_nulls = x y_data, y_nulls = y @@ -518,7 +516,7 @@ def arrow_nullable_tinyint_mult_with_masks( @udf(returns=[TEXT(nullable=False, name='res')]) -def numpy_fixed_strings() -> Table[npt.NDArray[np.str_]]: +def numpy_fixed_strings() -> Table[npt.StrArray]: out = np.array( [ 'hello', @@ -531,7 +529,7 @@ def numpy_fixed_strings() -> Table[npt.NDArray[np.str_]]: @udf(returns=[TEXT(nullable=False, name='res'), TINYINT(nullable=False, name='res2')]) -def numpy_fixed_strings_2() -> Table[npt.NDArray[np.str_], npt.NDArray[np.int8]]: +def numpy_fixed_strings_2() -> Table[npt.StrArray, npt.Int8Array]: out = np.array( [ 'hello', @@ -544,7 +542,7 @@ def numpy_fixed_strings_2() -> Table[npt.NDArray[np.str_], npt.NDArray[np.int8]] @udf(returns=[BLOB(nullable=False, name='res')]) -def numpy_fixed_binary() -> Table[npt.NDArray[np.bytes_]]: +def numpy_fixed_binary() -> Table[npt.BytesArray]: out = np.array( [ 'hello'.encode('utf8'), @@ -595,8 +593,8 @@ def table_function_struct(n: int) -> Table[List[MyTable]]: @udf def vec_function( - x: npt.NDArray[np.float64], y: npt.NDArray[np.float64], -) -> npt.NDArray[np.float64]: + x: npt.Float64Array, y: npt.Float64Array, +) -> npt.Float64Array: return x * y @@ -611,8 +609,8 @@ class VecOutputs(typing.NamedTuple): @udf(args=VecInputs, returns=VecOutputs) def vec_function_ints( - x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], -) -> npt.NDArray[np.int_]: + x: npt.IntArray, y: npt.IntArray, +) -> npt.IntArray: return x * y @@ -623,16 +621,16 @@ class DFOutputs(typing.NamedTuple): @udf(args=VecInputs, returns=DFOutputs) def vec_function_df( - x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], -) -> Table[pd.DataFrame]: - return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) + x: npt.IntArray, y: npt.IntArray, +) -> Table[pdt.DataFrame]: + return pdt.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) @udf(args=VecInputs, returns=DFOutputs) async def async_vec_function_df( 
- x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], -) -> Table[pd.DataFrame]: - return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) + x: npt.IntArray, y: npt.IntArray, +) -> Table[pdt.DataFrame]: + return pdt.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) class MaskOutputs(typing.NamedTuple): @@ -641,8 +639,8 @@ class MaskOutputs(typing.NamedTuple): @udf(args=VecInputs, returns=MaskOutputs) def vec_function_ints_masked( - x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]], -) -> Table[Masked[npt.NDArray[np.int_]]]: + x: Masked[npt.IntArray], y: Masked[npt.IntArray], +) -> Table[Masked[npt.IntArray]]: x_data, x_nulls = x y_data, y_nulls = y return Table(Masked(x_data * y_data, x_nulls | y_nulls)) @@ -655,8 +653,8 @@ class MaskOutputs2(typing.NamedTuple): @udf(args=VecInputs, returns=MaskOutputs2) def vec_function_ints_masked2( - x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]], -) -> Table[Masked[npt.NDArray[np.int_]], Masked[npt.NDArray[np.int_]]]: + x: Masked[npt.IntArray], y: Masked[npt.IntArray], +) -> Table[Masked[npt.IntArray], Masked[npt.IntArray]]: x_data, x_nulls = x y_data, y_nulls = y return Table( From a47ff35d69ac98c1d2ed601c9840830faeb69b4c Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 6 Jun 2025 15:58:25 -0500 Subject: [PATCH 06/10] Add metrics --- singlestoredb/functions/ext/asgi.py | 67 +++++++++++++++++++---------- 1 file changed, 45 insertions(+), 22 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index d3ee596d..5fe85ae3 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -40,8 +40,10 @@ import tempfile import textwrap import threading +import time import typing import urllib +import uuid import zipfile import zipimport from types import ModuleType @@ -69,6 +71,7 @@ from ..signature import signature_to_sql from ..typing import Masked from ..typing import Table +from .timer import Timer try: import cloudpickle @@ -613,6 +616,16 @@ def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None: pass +def start_counter() -> float: + """Start a timer and return the start time.""" + return time.perf_counter() + + +def end_counter(start: float) -> float: + """End a timer and return the elapsed time.""" + return time.perf_counter() - start + + class Application(object): """ Create an external function application. 
@@ -939,6 +952,8 @@ async def __call__( Function to send response information ''' + timer = Timer(id=str(uuid.uuid4()), timestamp=time.time()) + assert scope['type'] == 'http' method = scope['method'] @@ -964,12 +979,13 @@ async def __call__( returns_data_format = func_info['returns_data_format'] data = [] more_body = True - while more_body: - request = await receive() - if request['type'] == 'http.disconnect': - raise RuntimeError('client disconnected') - data.append(request['body']) - more_body = request.get('more_body', False) + with timer('receive_data'): + while more_body: + request = await receive() + if request['type'] == 'http.disconnect': + raise RuntimeError('client disconnected') + data.append(request['body']) + more_body = request.get('more_body', False) data_version = headers.get(b's2-ef-version', b'') input_handler = self.handlers[(content_type, data_version, args_data_format)] @@ -981,17 +997,17 @@ async def __call__( cancel_event = threading.Event() - func_args = [ - cancel_event, - *input_handler['load']( # type: ignore + with timer('parse_input'): + inputs = input_handler['load']( # type: ignore func_info['colspec'], b''.join(data), - ), - ] + ) func_task = asyncio.create_task( - func(*func_args) + func(cancel_event, *inputs) if func_info['is_async'] - else to_thread(lambda: asyncio.run(func(*func_args))), + else to_thread( + lambda: asyncio.run(func(cancel_event, *inputs)), + ), ) disconnect_task = asyncio.create_task( cancel_on_disconnect(receive), @@ -1002,9 +1018,10 @@ async def __call__( all_tasks += [func_task, disconnect_task, timeout_task] - done, pending = await asyncio.wait( - all_tasks, return_when=asyncio.FIRST_COMPLETED, - ) + with timer('function_call'): + done, pending = await asyncio.wait( + all_tasks, return_when=asyncio.FIRST_COMPLETED, + ) cancel_all_tasks(pending) @@ -1024,9 +1041,10 @@ async def __call__( elif task is func_task: result.extend(task.result()) - body = output_handler['dump']( - [x[1] for x in func_info['returns']], *result, # type: ignore - ) + with timer('format_output'): + body = output_handler['dump']( + [x[1] for x in func_info['returns']], *result, # type: ignore + ) await send(output_handler['response']) @@ -1089,9 +1107,14 @@ async def __call__( await send(self.path_not_found_response_dict) # Send body - out = self.body_response_dict.copy() - out['body'] = body - await send(out) + with timer('send_response'): + out = self.body_response_dict.copy() + out['body'] = body + await send(out) + + timer.metadata['function'] = func_name.decode('utf-8') if func_name else '' + timer.finish() + timer.log_metrics() def _create_link( self, From dc60df4ecad090d8b2cf2255f4fa3dad096fcd93 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 6 Jun 2025 16:17:13 -0500 Subject: [PATCH 07/10] Add metrics --- singlestoredb/functions/ext/timer.py | 133 +++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 singlestoredb/functions/ext/timer.py diff --git a/singlestoredb/functions/ext/timer.py b/singlestoredb/functions/ext/timer.py new file mode 100644 index 00000000..f0c94f00 --- /dev/null +++ b/singlestoredb/functions/ext/timer.py @@ -0,0 +1,133 @@ +import json +import time +from typing import Any +from typing import Dict +from typing import List + +from . 
import utils
+
+logger = utils.get_logger('singlestoredb.functions.ext.metrics')
+
+
+class RoundedFloatEncoder(json.JSONEncoder):
+
+    def encode(self, obj: Any) -> str:
+        if isinstance(obj, dict):
+            return '{' + ', '.join(
+                f'"{k}": {self._format_value(v)}'
+                for k, v in obj.items()
+            ) + '}'
+        return super().encode(obj)
+
+    def _format_value(self, value: Any) -> str:
+        if isinstance(value, float):
+            return f'{value:.2f}'
+        return json.dumps(value)
+
+
+class Timer:
+    """
+    Timer context manager that supports nested timing using a stack.
+
+    Example
+    -------
+    timer = Timer()
+
+    with timer('total'):
+        with timer('receive_data'):
+            time.sleep(0.1)
+        with timer('parse_input'):
+            time.sleep(0.2)
+        with timer('call_function'):
+            with timer('inner_operation'):
+                time.sleep(0.05)
+            time.sleep(0.3)
+
+    print(timer.metrics)
+    # {'receive_data': 0.1, 'parse_input': 0.2, 'inner_operation': 0.05,
+    #  'call_function': 0.35, 'total': 0.65}
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        """
+        Initialize the Timer.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Metadata key / value pairs to attach to the logged metrics
+
+        """
+        self.metadata: Dict[str, Any] = kwargs
+        self.metrics: Dict[str, float] = dict()
+        self._stack: List[Dict[str, Any]] = []
+        self.start_time = time.perf_counter()
+
+    def __call__(self, key: str) -> 'Timer':
+        """
+        Set the key for the next context manager usage.
+
+        Parameters
+        ----------
+        key : str
+            The key to store the execution time under
+
+        Returns
+        -------
+        Timer
+            Self, to be used as context manager
+
+        """
+        self._current_key = key
+        return self
+
+    def __enter__(self) -> 'Timer':
+        """Enter the context manager and start timing."""
+        if not hasattr(self, '_current_key'):
+            raise ValueError(
+                "No key specified. Use timer('key_name') as context manager.",
+            )
+
+        # Push current timing info onto stack
+        timing_info = {
+            'key': self._current_key,
+            'start_time': time.perf_counter(),
+        }
+        self._stack.append(timing_info)
+
+        # Clear current key for next use
+        delattr(self, '_current_key')
+
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        """Exit the context manager and store the elapsed time."""
+        if not self._stack:
+            return
+
+        # Pop the current timing from stack
+        timing_info = self._stack.pop()
+        elapsed = time.perf_counter() - timing_info['start_time']
+        self.metrics[timing_info['key']] = elapsed
+
+    def finish(self) -> None:
+        """Finish the current timing context and store the elapsed time."""
+        if self._stack:
+            raise RuntimeError(
+                'finish() called without a matching __enter__(). 
' + 'Use the context manager instead.', + ) + + self.metrics['total'] = time.perf_counter() - self.start_time + + self.log_metrics() + + def reset(self) -> None: + """Clear all stored times and reset the stack.""" + self.metrics.clear() + self._stack.clear() + + def log_metrics(self) -> None: + if self.metadata.get('function'): + result = dict(type='function_metrics', **self.metadata, **self.metrics) + logger.info(json.dumps(result, cls=RoundedFloatEncoder)) From ec9523f421e565c465eb2bba3883873dcecb6125 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 11 Jun 2025 13:13:48 -0500 Subject: [PATCH 08/10] Add layers of timings --- singlestoredb/functions/ext/asgi.py | 88 ++++++++++++++++------------ singlestoredb/functions/ext/timer.py | 8 +-- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 5fe85ae3..1e9b1a11 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -313,17 +313,19 @@ def build_udf_endpoint( async def do_func( cancel_event: threading.Event, + timer: Timer, row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' out = [] - for row in rows: - cancel_on_event(cancel_event) - if is_async: - out.append(await func(*row)) - else: - out.append(func(*row)) + with timer('call_function'): + for row in rows: + cancel_on_event(cancel_event) + if is_async: + out.append(await func(*row)) + else: + out.append(func(*row)) return row_ids, list(zip(out)) return do_func @@ -357,6 +359,7 @@ def build_vector_udf_endpoint( async def do_func( cancel_event: threading.Event, + timer: Timer, row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], ) -> Tuple[ @@ -367,16 +370,17 @@ async def do_func( row_ids = array_cls(row_ids) # Call the function with `cols` as the function parameters - if cols and cols[0]: - if is_async: - out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) - else: - out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) - else: - if is_async: - out = await func() + with timer('call_function'): + if cols and cols[0]: + if is_async: + out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) + else: + out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) else: - out = func() + if is_async: + out = await func() + else: + out = func() cancel_on_event(cancel_event) @@ -420,6 +424,7 @@ def build_tvf_endpoint( async def do_func( cancel_event: threading.Event, + timer: Timer, row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: @@ -427,14 +432,15 @@ async def do_func( out_ids: List[int] = [] out = [] # Call function on each row of data - for i, row in zip(row_ids, rows): - cancel_on_event(cancel_event) - if is_async: - res = await func(*row) - else: - res = func(*row) - out.extend(as_list_of_tuples(res)) - out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) + with timer('call_function'): + for i, row in zip(row_ids, rows): + cancel_on_event(cancel_event) + if is_async: + res = await func(*row) + else: + res = func(*row) + out.extend(as_list_of_tuples(res)) + out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out return do_func @@ -467,6 +473,7 @@ def build_vector_tvf_endpoint( async def do_func( cancel_event: threading.Event, + timer: Timer, row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], 
) -> Tuple[ @@ -481,20 +488,23 @@ async def do_func( is_async = asyncio.iscoroutinefunction(func) # Call function on each column of data - if cols and cols[0]: - if is_async: - res = get_dataframe_columns( - await func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) - else: - res = get_dataframe_columns( - func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) - else: - if is_async: - res = get_dataframe_columns(await func()) + with timer('call_function'): + if cols and cols[0]: + if is_async: + func_res = await func( + *[x if m else x[0] for x, m in zip(cols, masks)], + ) + else: + func_res = func( + *[x if m else x[0] for x, m in zip(cols, masks)], + ) else: - res = get_dataframe_columns(func()) + if is_async: + func_res = await func() + else: + func_res = func() + + res = get_dataframe_columns(func_res) cancel_on_event(cancel_event) @@ -1003,10 +1013,10 @@ async def __call__( ) func_task = asyncio.create_task( - func(cancel_event, *inputs) + func(cancel_event, timer, *inputs) if func_info['is_async'] else to_thread( - lambda: asyncio.run(func(cancel_event, *inputs)), + lambda: asyncio.run(func(cancel_event, timer, *inputs)), ), ) disconnect_task = asyncio.create_task( @@ -1018,7 +1028,7 @@ async def __call__( all_tasks += [func_task, disconnect_task, timeout_task] - with timer('function_call'): + with timer('function_wrapper'): done, pending = await asyncio.wait( all_tasks, return_when=asyncio.FIRST_COMPLETED, ) diff --git a/singlestoredb/functions/ext/timer.py b/singlestoredb/functions/ext/timer.py index f0c94f00..8f95fa77 100644 --- a/singlestoredb/functions/ext/timer.py +++ b/singlestoredb/functions/ext/timer.py @@ -108,15 +108,13 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # Pop the current timing from stack timing_info = self._stack.pop() elapsed = time.perf_counter() - timing_info['start_time'] - self.metrics[timing_info['key']] = elapsed + self.metrics.setdefault(timing_info['key'], 0) + self.metrics[timing_info['key']] += elapsed def finish(self) -> None: """Finish the current timing context and store the elapsed time.""" if self._stack: - raise RuntimeError( - 'finish() called without a matching __enter__(). 
' - 'Use the context manager instead.', - ) + raise RuntimeError('finish() called within a `with` block.') self.metrics['total'] = time.perf_counter() - self.start_time From 8aff597078e1d6d4cafbd8f154991828e35fe58d Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 11 Jun 2025 15:43:47 -0500 Subject: [PATCH 09/10] Add logging --- singlestoredb/functions/ext/asgi.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 1e9b1a11..163d5147 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -26,6 +26,7 @@ import asyncio import contextvars import dataclasses +import datetime import functools import importlib.util import inspect @@ -962,7 +963,14 @@ async def __call__( Function to send response information ''' - timer = Timer(id=str(uuid.uuid4()), timestamp=time.time()) + request_id = str(uuid.uuid4()) + + timer = Timer( + id=request_id, + timestamp=datetime.datetime.now( + datetime.timezone.utc, + ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + ) assert scope['type'] == 'http' @@ -978,6 +986,8 @@ async def __call__( func_name = headers.get(b's2-ef-name', b'') func_endpoint = self.endpoints.get(func_name) + timer.metadata['function'] = func_name.decode('utf-8') if func_name else '' + func = None func_info: Dict[str, Any] = {} if func_endpoint is not None: @@ -985,6 +995,17 @@ async def __call__( # Call the endpoint if method == 'POST' and func is not None and path == self.invoke_path: + + logger.info( + json.dumps({ + 'type': 'function_call', + 'id': request_id, + 'name': func_name.decode('utf-8'), + 'content_type': content_type.decode('utf-8'), + 'accepts': accepts.decode('utf-8'), + }), + ) + args_data_format = func_info['args_data_format'] returns_data_format = func_info['returns_data_format'] data = [] @@ -1122,9 +1143,7 @@ async def __call__( out['body'] = body await send(out) - timer.metadata['function'] = func_name.decode('utf-8') if func_name else '' timer.finish() - timer.log_metrics() def _create_link( self, From 432c08ac12c2dc884f381b1ec0411b05469813d2 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 12 Jun 2025 13:31:57 -0500 Subject: [PATCH 10/10] Update github workflows to only run PR checks against the management API if they contain relevant changes --- .github/workflows/code-check.yml | 173 +++++++++++++++++++++++++++++++ .github/workflows/coverage.yml | 10 +- 2 files changed, 178 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/code-check.yml diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml new file mode 100644 index 00000000..c52aca3a --- /dev/null +++ b/.github/workflows/code-check.yml @@ -0,0 +1,173 @@ +name: Coverage tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +jobs: + test-coverage: + runs-on: ubuntu-latest + environment: Base + + services: + singlestore: + image: ghcr.io/singlestore-labs/singlestoredb-dev:latest + ports: + - 3307:3306 + - 8081:8080 + - 9081:9081 + env: + SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }} + ROOT_PASSWORD: "root" + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r test-requirements.txt + + - name: Install 
From 432c08ac12c2dc884f381b1ec0411b05469813d2 Mon Sep 17 00:00:00 2001
From: Kevin Smith
Date: Thu, 12 Jun 2025 13:31:57 -0500
Subject: [PATCH 10/10] Update github workflows to only run PR checks against
 the management API if they contain relevant changes

---
 .github/workflows/code-check.yml | 173 +++++++++++++++++++++++++++++++
 .github/workflows/coverage.yml   |  10 +-
 2 files changed, 178 insertions(+), 5 deletions(-)
 create mode 100644 .github/workflows/code-check.yml

diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml
new file mode 100644
index 00000000..c52aca3a
--- /dev/null
+++ b/.github/workflows/code-check.yml
@@ -0,0 +1,173 @@
+name: Coverage tests
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+
+jobs:
+  test-coverage:
+    runs-on: ubuntu-latest
+    environment: Base
+
+    services:
+      singlestore:
+        image: ghcr.io/singlestore-labs/singlestoredb-dev:latest
+        ports:
+          - 3307:3306
+          - 8081:8080
+          - 9081:9081
+        env:
+          SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }}
+          ROOT_PASSWORD: "root"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+          cache: "pip"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r test-requirements.txt
+
+      - name: Install SingleStore package
+        run: |
+          pip install .
+
+      - name: Check for changes in monitored directories
+        id: check-changes
+        run: |
+          # Define directories to monitor (space-separated)
+          MONITORED_DIRS="singlestoredb/management singlestoredb/fusion"
+
+          # Determine the base commit to compare against
+          if [ "${{ github.event_name }}" == "pull_request" ]; then
+            # For PRs, compare against the target branch (usually main/master)
+            BASE_COMMIT="origin/${{ github.event.pull_request.base.ref }}"
+            echo "Pull Request: Comparing against $BASE_COMMIT"
+          elif [ "${{ github.ref_name }}" == "main" ] || [ "${{ github.ref_name }}" == "master" ]; then
+            # For pushes to main/master, compare against the previous commit
+            BASE_COMMIT="HEAD~1"
+            echo "Push to main/master: Comparing against $BASE_COMMIT"
+          else
+            # For pushes to other branches, compare against main/master
+            if git rev-parse --verify origin/main >/dev/null 2>&1; then
+              BASE_COMMIT="origin/main"
+              echo "Push to branch: Comparing against origin/main"
+            elif git rev-parse --verify origin/master >/dev/null 2>&1; then
+              BASE_COMMIT="origin/master"
+              echo "Push to branch: Comparing against origin/master"
+            else
+              # Fall back to the previous commit if main/master is not found
+              BASE_COMMIT="HEAD~1"
+              echo "Fallback: Comparing against HEAD~1"
+            fi
+          fi
+
+          echo "Checking for changes in: $MONITORED_DIRS"
+          echo "Comparing against: $BASE_COMMIT"
+
+          # Check for any changes in monitored directories
+          CHANGES_FOUND=false
+          CHANGED_DIRS=""
+
+          for DIR in $MONITORED_DIRS; do
+            if [ -d "$DIR" ]; then
+              CHANGED_FILES=$(git diff --name-only $BASE_COMMIT HEAD -- "$DIR" || true)
+              if [ -n "$CHANGED_FILES" ]; then
+                echo "✅ Changes detected in: $DIR"
+                echo "Files changed:"
+                echo "$CHANGED_FILES" | sed 's/^/  - /'
+                CHANGES_FOUND=true
+                if [ -z "$CHANGED_DIRS" ]; then
+                  CHANGED_DIRS="$DIR"
+                else
+                  CHANGED_DIRS="$CHANGED_DIRS,$DIR"
+                fi
+              else
+                echo "❌ No changes in: $DIR"
+              fi
+            else
+              echo "⚠️ Directory not found: $DIR"
+            fi
+          done
+
+          # Set outputs
+          if [ "$CHANGES_FOUND" = true ]; then
+            echo "changes-detected=true" >> $GITHUB_OUTPUT
+            echo "changed-directories=$CHANGED_DIRS" >> $GITHUB_OUTPUT
+            echo ""
+            echo "🎯 RESULT: Changes detected in monitored directories"
+          else
+            echo "changes-detected=false" >> $GITHUB_OUTPUT
+            echo "changed-directories=" >> $GITHUB_OUTPUT
+            echo ""
+            echo "🎯 RESULT: No changes in monitored directories"
+          fi
+
+      - name: Run MySQL protocol tests (with management API)
+        if: steps.check-changes.outputs.changes-detected == 'true'
+        run: |
+          pytest -v --cov=singlestoredb --pyargs singlestoredb.tests
+        env:
+          COVERAGE_FILE: "coverage-mysql.cov"
+          SINGLESTOREDB_URL: "root:root@127.0.0.1:3307"
+          SINGLESTOREDB_PURE_PYTHON: 0
+          SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }}
+          SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }}
+          SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1"
+
+      - name: Run MySQL protocol tests (without management API)
+        if: steps.check-changes.outputs.changes-detected == 'false'
+        run: |
+          pytest -v -m 'not management' --cov=singlestoredb --pyargs singlestoredb.tests
+        env:
+          COVERAGE_FILE: "coverage-mysql.cov"
+          SINGLESTOREDB_URL: "root:root@127.0.0.1:3307"
+          SINGLESTOREDB_PURE_PYTHON: 0
+          SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }}
+          SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }}
+          SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1"
+
+      - name: Run MySQL protocol tests (pure Python)
+        run: |
+          pytest -v -m 'not management' --cov=singlestoredb --pyargs singlestoredb.tests
+        env:
+          COVERAGE_FILE: "coverage-mysql-py.cov"
+          SINGLESTOREDB_URL: "root:root@127.0.0.1:3307"
+          SINGLESTOREDB_PURE_PYTHON: 1
+          SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }}
+          SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }}
+          SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1"
+
+      - name: Run HTTP protocol tests
+        run: |
+          pytest -v -m 'not management' --cov=singlestoredb --pyargs singlestoredb.tests
+        env:
+          COVERAGE_FILE: "coverage-http.cov"
+          SINGLESTOREDB_URL: "http://root:root@127.0.0.1:9081"
+          SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }}
+          SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }}
+          # Databases cannot be changed using the HTTP API. The URL below is
+          # used to create the database, and the generated database name is
+          # then applied to the URL above.
+          SINGLESTOREDB_INIT_DB_URL: "root:root@127.0.0.1:3307"
+          SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1"
+
+      - name: Generate report
+        run: |
+          coverage combine coverage-mysql.cov coverage-http.cov coverage-mysql-py.cov
+          coverage report
+          coverage xml
+          coverage html
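The "Check for changes" step above is plain shell, which is awkward to unit test. A rough Python equivalent of the same detection logic follows, assuming full history is available in the checkout (hence fetch-depth: 0) and origin/main as the default base commit; it is a sketch, not part of the patch.

```python
import subprocess

MONITORED_DIRS = ['singlestoredb/management', 'singlestoredb/fusion']


def changed_monitored_dirs(base_commit: str = 'origin/main') -> list:
    """Return the monitored directories that differ from base_commit."""
    changed = []
    for path in MONITORED_DIRS:
        result = subprocess.run(
            ['git', 'diff', '--name-only', base_commit, 'HEAD', '--', path],
            capture_output=True, text=True, check=True,
        )
        if result.stdout.strip():
            changed.append(path)
    return changed


if __name__ == '__main__':
    dirs = changed_monitored_dirs()
    # Emit key=value pairs in the same shape as the step's $GITHUB_OUTPUT writes.
    print(f'changes-detected={str(bool(dirs)).lower()}')
    print(f'changed-directories={",".join(dirs)}')
```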
"coverage-mysql-py.cov" + SINGLESTOREDB_URL: "root:root@127.0.0.1:3307" + SINGLESTOREDB_PURE_PYTHON: 1 + SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }} + SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }} + SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1" + + - name: Run HTTP protocol tests + run: | + pytest -v -m 'not management' --cov=singlestoredb --pyargs singlestoredb.tests + env: + COVERAGE_FILE: "coverage-http.cov" + SINGLESTOREDB_URL: "http://root:root@127.0.0.1:9081" + SINGLESTORE_LICENSE: ${{ secrets.SINGLESTORE_LICENSE }} + SINGLESTOREDB_MANAGEMENT_TOKEN: ${{ secrets.CLUSTER_API_KEY }} + # Can not change databases using HTTP API. The URL below will be + # used to create the database and the generated database name will + # be applied to the above URL. + SINGLESTOREDB_INIT_DB_URL: "root:root@127.0.0.1:3307" + SINGLESTOREDB_FUSION_ENABLE_HIDDEN: "1" + + - name: Generate report + run: | + coverage combine coverage-mysql.cov coverage-http.cov coverage-mysql-py.cov + coverage report + coverage xml + coverage html diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index aaee0d55..bf60140c 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -1,10 +1,8 @@ name: Coverage tests on: - push: - branches: [ main ] - pull_request: - branches: [ main ] + schedule: + - cron: "0 1 * * *" workflow_dispatch: jobs: @@ -24,7 +22,9 @@ jobs: ROOT_PASSWORD: "root" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v4