Skip to content

Commit 488bab4

Browse files
committed
Allow tuple-valued params in read_sql[_query]
1 parent fa7e444 commit 488bab4

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

pandas-stubs/io/sql.pyi

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ def read_sql_query(
6666
con: _SQLConnection,
6767
index_col: str | list[str] | None = ...,
6868
coerce_float: bool = ...,
69-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
69+
params: list[Scalar]
70+
| tuple[Scalar, ...]
71+
| tuple[tuple[Scalar, ...], ...]
72+
| Mapping[str, Scalar]
73+
| Mapping[str, tuple[Scalar, ...]]
74+
| None = ...,
7075
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
7176
*,
7277
chunksize: int,
@@ -79,7 +84,12 @@ def read_sql_query(
7984
con: _SQLConnection,
8085
index_col: str | list[str] | None = ...,
8186
coerce_float: bool = ...,
82-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
87+
params: list[Scalar]
88+
| tuple[Scalar, ...]
89+
| tuple[tuple[Scalar, ...], ...]
90+
| Mapping[str, Scalar]
91+
| Mapping[str, tuple[Scalar, ...]]
92+
| None = ...,
8393
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
8494
chunksize: None = ...,
8595
dtype: DtypeArg | None = ...,
@@ -91,7 +101,12 @@ def read_sql(
91101
con: _SQLConnection,
92102
index_col: str | list[str] | None = ...,
93103
coerce_float: bool = ...,
94-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
104+
params: list[Scalar]
105+
| tuple[Scalar, ...]
106+
| tuple[tuple[Scalar, ...], ...]
107+
| Mapping[str, Scalar]
108+
| Mapping[str, tuple[Scalar, ...]]
109+
| None = ...,
95110
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
96111
columns: list[str] = ...,
97112
*,
@@ -105,7 +120,12 @@ def read_sql(
105120
con: _SQLConnection,
106121
index_col: str | list[str] | None = ...,
107122
coerce_float: bool = ...,
108-
params: list[Scalar] | tuple[Scalar, ...] | Mapping[str, Scalar] | None = ...,
123+
params: list[Scalar]
124+
| tuple[Scalar, ...]
125+
| tuple[tuple[Scalar, ...], ...]
126+
| Mapping[str, Scalar]
127+
| Mapping[str, tuple[Scalar, ...]]
128+
| None = ...,
109129
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
110130
columns: list[str] = ...,
111131
chunksize: None = ...,

tests/test_io.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,39 @@ def test_read_sql_query_via_sqlalchemy_engine_with_params():
12381238
engine.dispose()
12391239

12401240

1241+
@pytest.mark.skip(
1242+
reason="Only works in Postgres (and MySQL, but with different query syntax)"
1243+
)
1244+
def test_read_sql_query_via_sqlalchemy_engine_with_tuple_valued_params():
1245+
with ensure_clean() as path:
1246+
db_uri = "postgresql+psycopg2://postgres@localhost:5432/postgres"
1247+
engine = sqlalchemy.create_engine(db_uri)
1248+
1249+
check(
1250+
assert_type(
1251+
read_sql_query(
1252+
"select * from test where a in %(a)s",
1253+
con=engine,
1254+
params={"a": (1, 2)},
1255+
),
1256+
DataFrame,
1257+
),
1258+
DataFrame,
1259+
)
1260+
check(
1261+
assert_type(
1262+
read_sql_query(
1263+
"select * from test where a in %s",
1264+
con=engine,
1265+
params=((1, 2),),
1266+
),
1267+
DataFrame,
1268+
),
1269+
DataFrame,
1270+
)
1271+
engine.dispose()
1272+
1273+
12411274
def test_read_html():
12421275
check(assert_type(DF.to_html(), str), str)
12431276
with ensure_clean() as path:

0 commit comments

Comments
 (0)