feat(arrow): Allow record batches output from read_sql #819
@@ -2,7 +2,7 @@
import importlib
import urllib.parse

from collections.abc import Iterator
from importlib.metadata import version
from pathlib import Path
from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> pd.DataFrame: ...
@@ -192,6 +193,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> pd.DataFrame: ...
@@ -207,6 +209,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> pa.Table: ...
@@ -222,6 +225,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> mpd.DataFrame: ...
@@ -237,6 +241,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> dd.DataFrame: ...
@@ -252,6 +257,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    pre_execution_query: list[str] | str | None = None,
    **kwargs
) -> pl.DataFrame: ...
@@ -260,7 +266,7 @@ def read_sql(
    query: list[str] | str,
    *,
    return_type: Literal[
        "pandas", "polars", "arrow", "modin", "dask"
        "pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
    ] = "pandas",
    protocol: Protocol | None = None,
    partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
    index_col: str | None = None,
    strategy: str | None = None,
    pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
    **kwargs
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
    """
    Run the SQL query, download the data from database into a dataframe.

    Parameters
    ==========
    conn
        the connection string, or dict of connection string mapping for federated query.
        the connection string, or dict of connection string mapping for a federated query.
    query
        a SQL query or a list of SQL queries.
    return_type
        the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
        the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
    protocol
        backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
        connection strings, where 'cursor' will be used instead).
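For orientation, a minimal usage sketch of the new return type added by this PR; the connection string, table name, and batch size are illustrative, not taken from the diff:

    import connectorx as cx

    # Returns a pyarrow.RecordBatchReader instead of a fully materialized table.
    # record_batch_size is the new keyword argument forwarded through **kwargs.
    reader = cx.read_sql(
        "postgresql://user:password@localhost:5432/db",
        "SELECT * FROM lineitem",
        return_type="arrow_record_batches",
        record_batch_size=10000,
    )

    for batch in reader:  # each item is a pyarrow.RecordBatch
        print(batch.num_rows)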
@@ -403,31 +411,59 @@ def read_sql(
        dd = try_import_module("dask.dataframe")
        df = dd.from_pandas(df, npartitions=1)

    elif return_type in {"arrow", "polars"}:
    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
        try_import_module("pyarrow")

        record_batch_size = int(kwargs.get("record_batch_size", 10000))
        result = _read_sql(
            conn,
            "arrow",
            "arrow_record_batches",
            queries=queries,
            protocol=protocol,
            partition_query=partition_query,
            pre_execution_queries=pre_execution_queries,
            record_batch_size=record_batch_size
        )
        df = reconstruct_arrow(result)
        if return_type in {"polars"}:
            pl = try_import_module("polars")
            try:
                df = pl.from_arrow(df)
            except AttributeError:
                # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
                df = pl.DataFrame.from_arrow(df)

        if return_type == "arrow_record_batches":
            df = reconstruct_arrow_rb(result)
        else:
            df = reconstruct_arrow(result)
            if return_type in {"polars"}:
                pl = try_import_module("polars")
                try:
                    df = pl.from_arrow(df)
                except AttributeError:
                    # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
                    df = pl.DataFrame.from_arrow(df)
    else:
        raise ValueError(return_type)

    return df
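Since the hunk above interleaves the old and new versions of this branch, here is roughly how the resulting code reads after the change; this is reconstructed from the added lines above and trimmed to the arrow paths, not copied from the file:

    elif return_type in {"arrow", "polars", "arrow_record_batches"}:
        try_import_module("pyarrow")
        # New keyword argument picked up via **kwargs; defaults to 10000 rows per batch.
        record_batch_size = int(kwargs.get("record_batch_size", 10000))
        result = _read_sql(
            conn,
            "arrow_record_batches",
            queries=queries,
            protocol=protocol,
            partition_query=partition_query,
            pre_execution_queries=pre_execution_queries,
            record_batch_size=record_batch_size,
        )
        if return_type == "arrow_record_batches":
            # Stream batches through a pyarrow.RecordBatchReader.
            df = reconstruct_arrow_rb(result)
        else:
            # Materialize a full pyarrow.Table (optionally converted to polars).
            df = reconstruct_arrow(result)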
def reconstruct_arrow_rb(results) -> Iterator[pa.RecordBatch]:
    import pyarrow as pa

    # Get Schema
    names, chunk_ptrs_list = results.schema_ptr()
    for chunk_ptrs in chunk_ptrs_list:
        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
        empty_rb = pa.RecordBatch.from_arrays(arrays, names)

    schema = empty_rb.schema

    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
        for rb_ptrs in iterator:
            names, chunk_ptrs_list = rb_ptrs.to_ptrs()
            for chunk_ptrs in chunk_ptrs_list:
                yield pa.RecordBatch.from_arrays(
                    [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
                )

    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
    import pyarrow as pa
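To show what the streaming return type enables end to end, here is a sketch of consuming the reader incrementally, for example writing it to Parquet so only one record batch is held in memory at a time; the connection string, query, and output path are illustrative:

    import pyarrow.parquet as pq
    import connectorx as cx

    reader = cx.read_sql(
        "postgresql://user:password@localhost:5432/db",
        "SELECT * FROM lineitem",
        return_type="arrow_record_batches",
    )

    # Write batch by batch instead of materializing a full pyarrow.Table first.
    with pq.ParquetWriter("lineitem.parquet", reader.schema) as writer:
        for batch in reader:
            writer.write_batch(batch)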
@@ -5,10 +5,77 @@ use connectorx::{prelude::*, sql::CXQuery};
use fehler::throws;
use libc::uintptr_t;
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::{PyAny, Python};
use std::convert::TryFrom;
use std::sync::Arc;

/// Python-exposed RecordBatch wrapper
#[pyclass]
pub struct PyRecordBatch(RecordBatch);

/// Python-exposed iterator over RecordBatches
#[pyclass(unsendable, module = "connectorx")]
pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);

#[pymethods]
impl PyRecordBatch {
    pub fn num_rows(&self) -> usize {
        self.0.num_rows()
    }

    pub fn num_columns(&self) -> usize {
        self.0.num_columns()
    }

    #[throws(ConnectorXPythonError)]
    pub fn to_ptrs<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
        let ptrs = py.allow_threads(
            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
                let rbs = vec![self.0.clone()];
Review comment: is this okay or do you suggest any workarounds?
Review reply: I think we can wrap over … Also, since we are using an iterator to generate a batch at a time, we do not need to wrap over a vector of batches.
                Ok(to_ptrs(rbs))
            },
        )?;
        let obj: PyObject = ptrs.into_py(py);
        obj.into_bound(py)
    }
}
#[pymethods]
impl PyRecordBatchIterator {

    #[throws(ConnectorXPythonError)]
    fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
        let (rb, _) = self.0.get_schema();
        let ptrs = py.allow_threads(
            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
                let rbs = vec![rb];
                Ok(to_ptrs(rbs))
            },
        )?;
        let obj: PyObject = ptrs.into_py(py);
        obj.into_bound(py)
    }

    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    fn __next__<'py>(
        mut slf: PyRefMut<'py, Self>,
        py: Python<'py>,
    ) -> PyResult<Option<Py<PyRecordBatch>>> {
        match slf.0.next_batch() {
            Some(rb) => {
                let wrapped = PyRecordBatch(rb);
                let py_obj = Py::new(py, wrapped)?;
                Ok(Some(py_obj))
            }

            None => Ok(None),
        }
    }
}
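For reference, a rough Python-side view of how these pyo3 methods are exercised; this mirrors reconstruct_arrow_rb from the Python diff above rather than documenting a public API, and the variable it simply stands in for the iterator object returned by the binding:

    # schema_ptr() exports an empty batch through the Arrow C Data Interface:
    # column names plus (array_ptr, schema_ptr) pointer pairs per column.
    names, schema_chunk_ptrs = it.schema_ptr()

    # __iter__/__next__ make the object a plain Python iterator over PyRecordBatch values.
    for py_rb in it:
        print(py_rb.num_rows(), py_rb.num_columns())
        names, chunk_ptrs_list = py_rb.to_ptrs()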
#[throws(ConnectorXPythonError)]
pub fn write_arrow<'py>(
    py: Python<'py>,

@@ -28,6 +95,30 @@ pub fn write_arrow<'py>(
    obj.into_bound(py)
}

#[throws(ConnectorXPythonError)]
pub fn get_arrow_rb_iter<'py>(
    py: Python<'py>,
    source_conn: &SourceConn,
    origin_query: Option<String>,
    queries: &[CXQuery<String>],
    pre_execution_queries: Option<&[String]>,
    batch_size: usize,
) -> Bound<'py, PyAny> {
    let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
        source_conn,
        origin_query,
        queries,
        batch_size,
        pre_execution_queries,
    );

    arrow_iter.prepare();
    let py_rb_iter = PyRecordBatchIterator(arrow_iter);

    let obj: PyObject = py_rb_iter.into_py(py);
    obj.into_bound(py)
}
pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
    if rbs.is_empty() {
        return (vec![], vec![]);
Review comment: Maybe use arrow_stream instead of arrow_record_batches for simplicity?