feat(arrow): Allow record batches output from read_sql #819

Open · wants to merge 3 commits into main · Changes from 2 commits
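
This PR adds an "arrow_record_batches" return type that streams query results as a pyarrow.RecordBatchReader instead of materializing them all at once. Below is a minimal usage sketch, assuming a placeholder connection string and query; the `record_batch_size` keyword and its default of 10000 come from the changed `__init__.py`:

```python
import connectorx as cx

# Placeholders for illustration only.
conn = "postgresql://user:pass@localhost:5432/db"
query = "SELECT * FROM lineitem"

# The new return type yields a pyarrow.RecordBatchReader; record_batch_size
# (default 10000 in the diff) controls how many rows go into each batch.
reader = cx.read_sql(
    conn,
    query,
    return_type="arrow_record_batches",
    record_batch_size=5000,
)

for batch in reader:  # each item is a pyarrow.RecordBatch
    print(batch.num_rows)
```

The existing "arrow" and "polars" return types keep their call pattern; per the diff they are now fetched through the same record-batch core path and then reconstructed into an eager table.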
66 changes: 51 additions & 15 deletions connectorx-python/connectorx/__init__.py
@@ -2,7 +2,7 @@

import importlib
import urllib.parse

from collections.abc import Iterator
from importlib.metadata import version
from pathlib import Path
from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pd.DataFrame: ...


@@ -192,6 +193,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pd.DataFrame: ...


@@ -207,6 +209,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pa.Table: ...


@@ -222,6 +225,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> mpd.DataFrame: ...


@@ -237,6 +241,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> dd.DataFrame: ...


@@ -252,6 +257,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pl.DataFrame: ...


@@ -260,7 +266,7 @@ def read_sql(
query: list[str] | str,
*,
return_type: Literal[
"pandas", "polars", "arrow", "modin", "dask"
"pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
] = "pandas",
protocol: Protocol | None = None,
partition_on: str | None = None,
@@ -269,18 +275,20 @@
index_col: str | None = None,
strategy: str | None = None,
pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
**kwargs

) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
"""
Run the SQL query and download the data from the database into a dataframe.

Parameters
==========
conn
the connection string, or dict of connection string mapping for federated query.
the connection string, or dict of connection string mapping for a federated query.
query
a SQL query or a list of SQL queries.
return_type
the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
protocol
backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
connection strings, where 'cursor' will be used instead).
@@ -403,31 +411,59 @@ def read_sql(
dd = try_import_module("dask.dataframe")
df = dd.from_pandas(df, npartitions=1)

elif return_type in {"arrow", "polars"}:
elif return_type in {"arrow", "polars", "arrow_record_batches"}:
try_import_module("pyarrow")

record_batch_size = int(kwargs.get("record_batch_size", 10000))
result = _read_sql(
conn,
"arrow",
"arrow_record_batches",
queries=queries,
protocol=protocol,
partition_query=partition_query,
pre_execution_queries=pre_execution_queries,
record_batch_size=record_batch_size
)
df = reconstruct_arrow(result)
if return_type in {"polars"}:
pl = try_import_module("polars")
try:
df = pl.from_arrow(df)
except AttributeError:
# previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
df = pl.DataFrame.from_arrow(df)

if return_type == "arrow_record_batches":
df = reconstruct_arrow_rb(result)
else:
df = reconstruct_arrow(result)
if return_type in {"polars"}:
pl = try_import_module("polars")
try:
df = pl.from_arrow(df)
except AttributeError:
# previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
df = pl.DataFrame.from_arrow(df)
else:
raise ValueError(return_type)

return df


def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
@chitralverma (author), Jun 20, 2025:
Returns a pyarrow RecordBatchReader instead of an iterator/generator of RecordBatch. I guess this will be useful for users who want to get the pyarrow Schema, since RecordBatchReader has it.

import pyarrow as pa

# Get Schema
names, chunk_ptrs_list = results.schema_ptr()
for chunk_ptrs in chunk_ptrs_list:
arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
empty_rb = pa.RecordBatch.from_arrays(arrays, names)

schema = empty_rb.schema

def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
for rb_ptrs in iterator:
names, chunk_ptrs_list = rb_ptrs.to_ptrs()
for chunk_ptrs in chunk_ptrs_list:
yield pa.RecordBatch.from_arrays(
[pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
)

return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))


def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
import pyarrow as pa

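As the author's comment on `reconstruct_arrow_rb` notes, returning a `pa.RecordBatchReader` rather than a bare batch generator makes the schema available before any rows are pulled. A short sketch of what that enables, again with placeholder connection details:

```python
import connectorx as cx

# Placeholders, as in the earlier sketch.
conn = "postgresql://user:pass@localhost:5432/db"
query = "SELECT * FROM lineitem"

reader = cx.read_sql(conn, query, return_type="arrow_record_batches")

print(reader.schema)      # schema is known up front, before any batch is read

first = next(reader)      # pull a single pyarrow.RecordBatch lazily
rest = reader.read_all()  # or drain the remaining batches into an eager pa.Table
```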
4 changes: 3 additions & 1 deletion connectorx-python/connectorx/connectorx.pyi
@@ -26,15 +26,17 @@ def read_sql(
queries: list[str] | None,
partition_query: dict[str, Any] | None,
pre_execution_queries: list[str] | None,
**kwargs
) -> _DataframeInfos: ...
@overload
def read_sql(
conn: str,
return_type: Literal["arrow"],
return_type: Literal["arrow", "arrow_record_batches"],
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
pre_execution_queries: list[str] | None,
**kwargs
) -> _ArrowInfos: ...
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
90 changes: 90 additions & 0 deletions connectorx-python/src/arrow.rs
@@ -5,10 +5,76 @@ use connectorx::{prelude::*, sql::CXQuery};
use fehler::throws;
use libc::uintptr_t;
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::{PyAny, Python};
use std::convert::TryFrom;
use std::sync::Arc;

/// Python-exposed RecordBatch wrapper
#[pyclass]
pub struct PyRecordBatch(RecordBatch);

/// Python-exposed iterator over RecordBatches
#[pyclass(unsendable, module = "connectorx")]
pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);

#[pymethods]
impl PyRecordBatch {
pub fn num_rows(&self) -> usize {
self.0.num_rows()
}

pub fn num_columns(&self) -> usize {
self.0.num_columns()
}

#[throws(ConnectorXPythonError)]
pub fn to_ptrs<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
let ptrs = py.allow_threads(
|| -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
let rbs = vec![self.0.clone()];
@chitralverma (author), Jun 20, 2025:

Is this okay, or do you suggest any workarounds?

# doesn't work without `.clone()`, breaks with the following
cannot move out of `self` which is behind a shared reference
move occurs because `self.0` has type `arrow::array::RecordBatch`, which does not implement the `Copy` trait

Ok(to_ptrs(rbs))
},
)?;
let obj: PyObject = ptrs.into_py(py);
obj.into_bound(py)
}
}

#[pymethods]
impl PyRecordBatchIterator {
#[throws(ConnectorXPythonError)]
fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
let (rb, _) = self.0.get_schema();
let ptrs = py.allow_threads(
|| -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
let rbs = vec![rb];
Ok(to_ptrs(rbs))
},
)?;
let obj: PyObject = ptrs.into_py(py);
obj.into_bound(py)
}
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __next__<'py>(
mut slf: PyRefMut<'py, Self>,
py: Python<'py>,
) -> PyResult<Option<Py<PyRecordBatch>>> {
match slf.0.next_batch() {
Some(rb) => {
let wrapped = PyRecordBatch(rb);
let py_obj = Py::new(py, wrapped)?;
Ok(Some(py_obj))
}

None => Ok(None),
}
}
}

#[throws(ConnectorXPythonError)]
pub fn write_arrow<'py>(
py: Python<'py>,
@@ -28,6 +94,30 @@ pub fn write_arrow<'py>(
obj.into_bound(py)
}

#[throws(ConnectorXPythonError)]
pub fn get_arrow_rb_iter<'py>(
py: Python<'py>,
source_conn: &SourceConn,
origin_query: Option<String>,
queries: &[CXQuery<String>],
pre_execution_queries: Option<&[String]>,
batch_size: usize,
) -> Bound<'py, PyAny> {
let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
source_conn,
origin_query,
queries,
batch_size,
pre_execution_queries,
);

arrow_iter.prepare();
let py_rb_iter = PyRecordBatchIterator(arrow_iter);

let obj: PyObject = py_rb_iter.into_py(py);
obj.into_bound(py)
}

pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
if rbs.is_empty() {
return (vec![], vec![]);
18 changes: 18 additions & 0 deletions connectorx-python/src/cx_read_sql.rs
@@ -8,6 +8,7 @@ use pyo3::prelude::*;
use pyo3::{exceptions::PyValueError, PyResult};

use crate::errors::ConnectorXPythonError;
use pyo3::types::PyDict;

#[derive(FromPyObject)]
#[pyo3(from_item_all)]
@@ -39,6 +40,7 @@ pub fn read_sql<'py>(
queries: Option<Vec<String>>,
partition_query: Option<PyPartitionQuery>,
pre_execution_queries: Option<Vec<String>>,
kwargs: Option<&Bound<PyDict>>,
) -> PyResult<Bound<'py, PyAny>> {
let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
let (queries, origin_query) = match (queries, partition_query) {
@@ -72,6 +74,22 @@
&queries,
pre_execution_queries.as_deref(),
)?),
"arrow_record_batches" => {
let batch_size = kwargs
.and_then(|dict| dict.get_item("record_batch_size").ok().flatten())
.and_then(|obj| obj.extract::<usize>().ok())
.unwrap_or(10000);

Ok(crate::arrow::get_arrow_rb_iter(
py,
&source_conn,
origin_query,
&queries,
pre_execution_queries.as_deref(),
batch_size,
)?)
}

_ => Err(PyValueError::new_err(format!(
"return type should be 'pandas' or 'arrow', got '{}'",
return_type
7 changes: 6 additions & 1 deletion connectorx-python/src/lib.rs
@@ -8,6 +8,7 @@ use crate::constants::J4RS_BASE_PATH;
use ::connectorx::{fed_dispatcher::run, partition::partition, source_router::parse_source};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::{wrap_pyfunction, PyResult};
use std::collections::HashMap;
use std::env;
@@ -35,11 +36,13 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(partition_sql))?;
m.add_wrapped(wrap_pyfunction!(get_meta))?;
m.add_class::<pandas::PandasBlockInfo>()?;
m.add_class::<arrow::PyRecordBatch>()?;
m.add_class::<arrow::PyRecordBatchIterator>()?;
Ok(())
}

#[pyfunction]
#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None, *, **kwargs))]
pub fn read_sql<'py>(
py: Python<'py>,
conn: &str,
@@ -48,6 +51,7 @@ pub fn read_sql<'py>(
queries: Option<Vec<String>>,
partition_query: Option<cx_read_sql::PyPartitionQuery>,
pre_execution_queries: Option<Vec<String>>,
kwargs: Option<&Bound<PyDict>>,
) -> PyResult<Bound<'py, PyAny>> {
cx_read_sql::read_sql(
py,
@@ -57,6 +61,7 @@ pub fn read_sql(
queries,
partition_query,
pre_execution_queries,
kwargs,
)
}

2 changes: 1 addition & 1 deletion connectorx/src/arrow_batch_iter.rs
@@ -149,7 +149,7 @@ where
type Item = RecordBatch;
/// NOTE: not thread safe
fn next(&mut self) -> Option<Self::Item> {
self.dst.record_batch().unwrap()
self.dst.record_batch().ok().flatten()
}
}

2 changes: 1 addition & 1 deletion connectorx/src/destinations/arrowstream/mod.rs
@@ -221,7 +221,7 @@ impl ArrowPartitionWriter {
.map(|(builder, &dt)| Realize::<FFinishBuilder>::realize(dt)?(builder))
.collect::<std::result::Result<Vec<_>, crate::errors::ConnectorXError>>()?;
let rb = RecordBatch::try_new(Arc::clone(&self.arrow_schema), columns)?;
self.sender.as_ref().unwrap().send(rb).unwrap();
self.sender.as_ref().and_then(|s| s.send(rb).ok());

self.current_row = 0;
self.current_col = 0;