From 016adc0b027e55af4a3aa33375ed905c06979050 Mon Sep 17 00:00:00 2001
From: chitralverma
Date: Fri, 20 Jun 2025 14:36:59 +0530
Subject: [PATCH 1/3] Allow record batches

---
 connectorx-python/connectorx/__init__.py    | 66 +++++++++++---
 connectorx-python/connectorx/connectorx.pyi |  4 +-
 connectorx-python/src/arrow.rs              | 91 +++++++++++++++++++
 connectorx-python/src/cx_read_sql.rs        | 18 ++++
 connectorx-python/src/lib.rs                |  7 +-
 connectorx/src/arrow_batch_iter.rs          |  2 +-
 .../src/destinations/arrowstream/mod.rs     |  2 +-
 7 files changed, 171 insertions(+), 19 deletions(-)

diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 643a804678..acb547c4fa 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -2,7 +2,7 @@
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
 
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
 
 
@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...
 
 
@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...
 
 
@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...
 
 
@@ -252,6 +257,7 @@ def read_sql(
    partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...
 
 
@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query, download the data from database into a dataframe.
 
     Parameters
     ==========
     conn
-        the connection string, or dict of connection string mapping for federated query.
+        the connection string, or dict of connection string mapping for a federated query.
     query
         a SQL query or a list of SQL queries.
     return_type
-        the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+        the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
     protocol
         backend-specific transfer protocol directive; defaults to 'binary' (except for
         redshift connection strings, where 'cursor' will be used instead).
@@ -403,31 +411,59 @@ def read_sql(
         dd = try_import_module("dask.dataframe")
         df = dd.from_pandas(df, npartitions=1)
 
-    elif return_type in {"arrow", "polars"}:
+    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
         try_import_module("pyarrow")
 
+        record_batch_size = int(kwargs.get("record_batch_size", 10000))
         result = _read_sql(
             conn,
-            "arrow",
+            "arrow_record_batches",
             queries=queries,
             protocol=protocol,
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
+            record_batch_size=record_batch_size
         )
-        df = reconstruct_arrow(result)
-        if return_type in {"polars"}:
-            pl = try_import_module("polars")
-            try:
-                df = pl.from_arrow(df)
-            except AttributeError:
-                # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
-                df = pl.DataFrame.from_arrow(df)
+
+        if return_type == "arrow_record_batches":
+            df = reconstruct_arrow_rb(result)
+        else:
+            df = reconstruct_arrow(result)
+            if return_type in {"polars"}:
+                pl = try_import_module("polars")
+                try:
+                    df = pl.from_arrow(df)
+                except AttributeError:
+                    # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
+                    df = pl.DataFrame.from_arrow(df)
     else:
         raise ValueError(return_type)
 
     return df
 
 
+def reconstruct_arrow_rb(results) -> Iterator[pa.RecordBatch]:
+    import pyarrow as pa
+
+    # Get Schema
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        for rb_ptrs in iterator:
+            names, chunk_ptrs_list = rb_ptrs.to_ptrs()
+            for chunk_ptrs in chunk_ptrs_list:
+                yield pa.RecordBatch.from_arrays(
+                    [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+                )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
+
+
 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
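
A minimal consumption sketch of the arrow_record_batches path added above. The
connection string and query are placeholders, and record_batch_size is the
optional keyword that travels through **kwargs per this diff:

    import connectorx as cx

    # read_sql now hands back a pyarrow.RecordBatchReader instead of a
    # fully materialized pyarrow.Table.
    reader = cx.read_sql(
        "postgresql://user:pass@localhost:5432/db",  # placeholder DSN
        "SELECT * FROM lineitem",                    # placeholder query
        return_type="arrow_record_batches",
        record_batch_size=20_000,  # optional; the Rust side defaults to 10000
    )
    for batch in reader:  # each item is a pyarrow.RecordBatch
        print(batch.num_rows, batch.schema.names)
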
diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi
index f63709c079..f34c58c160 100644
--- a/connectorx-python/connectorx/connectorx.pyi
+++ b/connectorx-python/connectorx/connectorx.pyi
@@ -26,15 +26,17 @@ def read_sql(
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _DataframeInfos: ...
 @overload
 def read_sql(
     conn: str,
-    return_type: Literal["arrow"],
+    return_type: Literal["arrow", "arrow_record_batches"],
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
 def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
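
reconstruct_arrow_rb above stays lazy because pa.RecordBatchReader.from_batches
accepts a generator: batches cross the C Data Interface only as the reader is
pulled. A self-contained illustration of that pattern with plain pyarrow (no
FFI involved):

    import pyarrow as pa

    schema = pa.schema([("x", pa.int64())])

    def generate_batches():
        # Stand-in for the pointer-importing generator above; nothing
        # executes until the reader is consumed.
        for i in range(3):
            yield pa.RecordBatch.from_arrays([pa.array([i] * 4)], names=["x"])

    reader = pa.RecordBatchReader.from_batches(schema, generate_batches())
    assert reader.read_all().num_rows == 12  # 3 batches x 4 rows
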
diff --git a/connectorx-python/src/arrow.rs b/connectorx-python/src/arrow.rs
index f506f60dd8..e2d02cfa6b 100644
--- a/connectorx-python/src/arrow.rs
+++ b/connectorx-python/src/arrow.rs
@@ -5,10 +5,77 @@ use connectorx::{prelude::*, sql::CXQuery};
 use fehler::throws;
 use libc::uintptr_t;
 use pyo3::prelude::*;
+use pyo3::pyclass;
 use pyo3::{PyAny, Python};
 use std::convert::TryFrom;
 use std::sync::Arc;
 
+/// Python-exposed RecordBatch wrapper
+#[pyclass]
+pub struct PyRecordBatch(RecordBatch);
+
+/// Python-exposed iterator over RecordBatches
+#[pyclass(unsendable, module = "connectorx")]
+pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);
+
+#[pymethods]
+impl PyRecordBatch {
+    pub fn num_rows(&self) -> usize {
+        self.0.num_rows()
+    }
+
+    pub fn num_columns(&self) -> usize {
+        self.0.num_columns()
+    }
+
+    #[throws(ConnectorXPythonError)]
+    pub fn to_ptrs<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        let ptrs = py.allow_threads(
+            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
+                let rbs = vec![self.0.clone()];
+                Ok(to_ptrs(rbs))
+            },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+}
+
+#[pymethods]
+impl PyRecordBatchIterator {
+
+    #[throws(ConnectorXPythonError)]
+    fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        let (rb, _) = self.0.get_schema();
+        let ptrs = py.allow_threads(
+            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
+                let rbs = vec![rb];
+                Ok(to_ptrs(rbs))
+            },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __next__<'py>(
+        mut slf: PyRefMut<'py, Self>,
+        py: Python<'py>,
+    ) -> PyResult<Option<Py<PyRecordBatch>>> {
+        match slf.0.next_batch() {
+            Some(rb) => {
+                let wrapped = PyRecordBatch(rb);
+                let py_obj = Py::new(py, wrapped)?;
+                Ok(Some(py_obj))
+            }
+
+            None => Ok(None),
+        }
+    }
+}
+
 #[throws(ConnectorXPythonError)]
 pub fn write_arrow<'py>(
     py: Python<'py>,
@@ -28,6 +95,30 @@ pub fn write_arrow<'py>(
     obj.into_bound(py)
 }
 
+#[throws(ConnectorXPythonError)]
+pub fn get_arrow_rb_iter<'py>(
+    py: Python<'py>,
+    source_conn: &SourceConn,
+    origin_query: Option<String>,
+    queries: &[CXQuery],
+    pre_execution_queries: Option<&[String]>,
+    batch_size: usize,
+) -> Bound<'py, PyAny> {
+    let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
+        source_conn,
+        origin_query,
+        queries,
+        batch_size,
+        pre_execution_queries,
+    );
+
+    arrow_iter.prepare();
+    let py_rb_iter = PyRecordBatchIterator(arrow_iter);
+
+    let obj: PyObject = py_rb_iter.into_py(py);
+    obj.into_bound(py)
+}
+
 pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
     if rbs.is_empty() {
         return (vec![], vec![]);
diff --git a/connectorx-python/src/cx_read_sql.rs b/connectorx-python/src/cx_read_sql.rs
index a95981716d..080fbcbbd6 100644
--- a/connectorx-python/src/cx_read_sql.rs
+++ b/connectorx-python/src/cx_read_sql.rs
@@ -8,6 +8,7 @@ use pyo3::prelude::*;
 use pyo3::{exceptions::PyValueError, PyResult};
 
 use crate::errors::ConnectorXPythonError;
+use pyo3::types::PyDict;
 
 #[derive(FromPyObject)]
 #[pyo3(from_item_all)]
@@ -39,6 +40,7 @@ pub fn read_sql<'py>(
     queries: Option<Vec<String>>,
     partition_query: Option<PartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<'py, PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
     let (queries, origin_query) = match (queries, partition_query) {
@@ -72,6 +74,22 @@ pub fn read_sql<'py>(
             &queries,
             pre_execution_queries.as_deref(),
         )?),
+        "arrow_record_batches" => {
+            let batch_size = kwargs
+                .and_then(|dict| dict.get_item("record_batch_size").ok().flatten())
+                .and_then(|obj| obj.extract::<usize>().ok())
+                .unwrap_or(10000);
+
+            Ok(crate::arrow::get_arrow_rb_iter(
+                py,
+                &source_conn,
+                origin_query,
+                &queries,
+                pre_execution_queries.as_deref(),
+                batch_size,
+            )?)
+        }
+
         _ => Err(PyValueError::new_err(format!(
             "return type should be 'pandas' or 'arrow', got '{}'",
             return_type
diff --git a/connectorx-python/src/lib.rs b/connectorx-python/src/lib.rs
index b4f573b35a..d6286dddfc 100644
--- a/connectorx-python/src/lib.rs
+++ b/connectorx-python/src/lib.rs
@@ -8,6 +8,7 @@ use crate::constants::J4RS_BASE_PATH;
 use ::connectorx::{fed_dispatcher::run, partition::partition, source_router::parse_source};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
+use pyo3::types::PyDict;
 use pyo3::{wrap_pyfunction, PyResult};
 use std::collections::HashMap;
 use std::env;
@@ -35,11 +36,13 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(partition_sql))?;
     m.add_wrapped(wrap_pyfunction!(get_meta))?;
     m.add_class::<pandas::PandasBlockInfo>()?;
+    m.add_class::<arrow::PyRecordBatch>()?;
+    m.add_class::<arrow::PyRecordBatchIterator>()?;
     Ok(())
 }
 
 #[pyfunction]
-#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
+#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None, **kwargs))]
 pub fn read_sql<'py>(
     py: Python<'py>,
     conn: &str,
@@ -48,6 +51,7 @@ pub fn read_sql<'py>(
     queries: Option<Vec<String>>,
     partition_query: Option<cx_read_sql::PartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<'py, PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     cx_read_sql::read_sql(
         py,
@@ -57,6 +61,7 @@ pub fn read_sql<'py>(
         queries,
         partition_query,
         pre_execution_queries,
+        kwargs,
     )
 }
 
diff --git a/connectorx/src/arrow_batch_iter.rs b/connectorx/src/arrow_batch_iter.rs
index 1794a96161..557593aeee 100644
--- a/connectorx/src/arrow_batch_iter.rs
+++ b/connectorx/src/arrow_batch_iter.rs
@@ -149,7 +149,7 @@ where
     type Item = RecordBatch;
     /// NOTE: not thread safe
     fn next(&mut self) -> Option<Self::Item> {
-        self.dst.record_batch().unwrap()
+        self.dst.record_batch().ok().flatten()
     }
 }
 
diff --git a/connectorx/src/destinations/arrowstream/mod.rs b/connectorx/src/destinations/arrowstream/mod.rs
index d8487a268c..089b927bb5 100644
--- a/connectorx/src/destinations/arrowstream/mod.rs
+++ b/connectorx/src/destinations/arrowstream/mod.rs
@@ -221,7 +221,7 @@ impl ArrowPartitionWriter {
             .map(|(builder, &dt)| Realize::<FFinishBuilder>::realize(dt)?(builder))
             .collect::<Result<Vec<ArrayRef>, crate::errors::ConnectorXError>>()?;
         let rb = RecordBatch::try_new(Arc::clone(&self.arrow_schema), columns)?;
-        self.sender.as_ref().unwrap().send(rb).unwrap();
+        self.sender.as_ref().and_then(|s| s.send(rb).ok());
 
         self.current_row = 0;
         self.current_col = 0;
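
The payoff of PATCH 1 is bounded memory. A sketch of chunk-at-a-time
processing, assuming a reader obtained as in the earlier note (the output
path is a placeholder):

    import pyarrow.parquet as pq

    # Spill the result set to Parquet one batch at a time; peak memory
    # stays around record_batch_size rows rather than the whole table.
    with pq.ParquetWriter("out.parquet", reader.schema) as writer:
        for batch in reader:
            writer.write_batch(batch)
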
From 03c5f784881e528f787aa953f6572348d60f1773 Mon Sep 17 00:00:00 2001
From: chitralverma
Date: Fri, 20 Jun 2025 19:27:09 +0530
Subject: [PATCH 2/3] fix type

---
 connectorx-python/connectorx/__init__.py | 2 +-
 connectorx-python/src/arrow.rs           | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index acb547c4fa..b422f28511 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -442,7 +442,7 @@ def read_sql(
     return df
 
 
-def reconstruct_arrow_rb(results) -> Iterator[pa.RecordBatch]:
+def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
     import pyarrow as pa
 
     # Get Schema
diff --git a/connectorx-python/src/arrow.rs b/connectorx-python/src/arrow.rs
index e2d02cfa6b..76ec753399 100644
--- a/connectorx-python/src/arrow.rs
+++ b/connectorx-python/src/arrow.rs
@@ -43,7 +43,6 @@ impl PyRecordBatch {
 
 #[pymethods]
 impl PyRecordBatchIterator {
-
     #[throws(ConnectorXPythonError)]
     fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
         let (rb, _) = self.0.get_schema();

From 0f1ba57c2de6e651c8d31949f08dd648cea723e1 Mon Sep 17 00:00:00 2001
From: chitralverma
Date: Fri, 20 Jun 2025 20:40:10 +0530
Subject: [PATCH 3/3] fix: make RecordBatchIterator Send to support
 multi-threaded consumers like DuckDB

---
 connectorx-python/src/arrow.rs     | 2 +-
 connectorx/src/arrow_batch_iter.rs | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/connectorx-python/src/arrow.rs b/connectorx-python/src/arrow.rs
index 76ec753399..247387d0ad 100644
--- a/connectorx-python/src/arrow.rs
+++ b/connectorx-python/src/arrow.rs
@@ -15,7 +15,7 @@ use std::sync::Arc;
 pub struct PyRecordBatch(RecordBatch);
 
 /// Python-exposed iterator over RecordBatches
-#[pyclass(unsendable, module = "connectorx")]
+#[pyclass(module = "connectorx")]
 pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);
 
 #[pymethods]
diff --git a/connectorx/src/arrow_batch_iter.rs b/connectorx/src/arrow_batch_iter.rs
index 557593aeee..5ab4784449 100644
--- a/connectorx/src/arrow_batch_iter.rs
+++ b/connectorx/src/arrow_batch_iter.rs
@@ -153,7 +153,7 @@ where
     }
 }
 
-pub trait RecordBatchIterator {
+pub trait RecordBatchIterator: Send {
     fn get_schema(&self) -> (RecordBatch, &[String]);
     fn prepare(&mut self);
     fn next_batch(&mut self) -> Option<RecordBatch>;
@@ -167,7 +167,7 @@ where
         TSD = ArrowStreamTypeSystem,
         S = S,
         D = ArrowStreamDestination,
-    >,
+    > + std::marker::Send,
 {
     fn get_schema(&self) -> (RecordBatch, &[String]) {
         (self.dst.empty_batch(), self.dst.names())
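
PATCH 3 exists because consumers like DuckDB drain an Arrow stream from their
own worker threads, which is what forces the Send bound. A sketch of that
consumer path, with conn and query as placeholders:

    import duckdb
    import connectorx as cx

    reader = cx.read_sql(conn, query, return_type="arrow_record_batches")
    # DuckDB's replacement scan resolves the local variable name "reader"
    # and may pull batches from multiple threads, hence the Send bound.
    duckdb.sql("SELECT count(*) FROM reader").show()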