Skip to content

Commit b2d390c

Browse files
BryanLeeNightMarcher
authored andcommitted
feat: upgrade sqlize_value with tortoise.converters, and replace it with escape
1 parent 81db3bf commit b2d390c

File tree

5 files changed

+105
-46
lines changed

5 files changed

+105
-46
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Generate sql and execute
116116
```sql
117117
UPDATE `account` SET extend =
118118
JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}')
119-
, active=True, name='new_name'
119+
, active=1, name='new_name'
120120
WHERE `id`=8
121121
```
122122

@@ -162,7 +162,7 @@ Generate sql and execute
162162
```sql
163163
INSERT INTO `account_bak`
164164
(gender, locale, active, name, extend)
165-
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
165+
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
166166
FROM `account`
167167
WHERE `id` IN (4,5,6)
168168
```
@@ -185,8 +185,8 @@ Generate sql and execute
185185
JOIN (
186186
SELECT * FROM (
187187
VALUES
188-
ROW(7, False, False, 1, '{"test": 1, "debug": 0}'),
189-
ROW(15, False, True, 0, '{"test": 1, "debug": 0}')
188+
ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'),
189+
ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}')
190190
) AS fly_table (id, deleted, active, gender, extend)
191191
) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted
192192
SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend)

fastapi_esql/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from logging.config import dictConfig
22

3-
from tortoise.converters import escape_string
3+
from tortoise.converters import escape_item, escape_string
44
from tortoise.queryset import Q
55

66
from .const import (
@@ -23,7 +23,7 @@
2323
wrap_backticks,
2424
)
2525

26-
__version__ = "0.0.14"
26+
__version__ = "0.0.15"
2727

2828
__all__ = [
2929
"QsParsingError",
@@ -38,6 +38,7 @@
3838
"SQLizer",
3939
"Singleton",
4040
"convert_dicts",
41+
"escape_item",
4142
"escape_string",
4243
"timing",
4344
"wrap_backticks",

fastapi_esql/utils/sqlizer.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from enum import Enum
12
from logging import getLogger
23
from json import dumps
34
from typing import Any, Dict, List, Optional, Union
45

56
from tortoise import Model, __version__ as tortoise_version
7+
from tortoise.converters import escape_item
68
from tortoise.queryset import Q
79
from tortoise.query_utils import QueryModifier
810

@@ -35,10 +37,10 @@ def __init__(self, field: str, whens: dict, default=None):
3537
@property
3638
def sql(self):
3739
whens = " ".join(
38-
f"WHEN {k} THEN {SQLizer.sqlize_value(v)}"
40+
f"WHEN {k} THEN {SQLizer.escape(v)}"
3941
for k, v in self.whens.items()
4042
)
41-
else_ = " ELSE " + SQLizer.sqlize_value(self.default) if self.default is not None else ""
43+
else_ = " ELSE " + SQLizer.escape(self.default) if self.default is not None else ""
4244
return f"CASE {self.field} {whens}{else_} END"
4345

4446

@@ -89,25 +91,51 @@ def resolve_orders(cls, orders: List[str]) -> str:
8991
return ", ".join(orders_)
9092

9193
@classmethod
92-
def sqlize_value(cls, value, to_json=False) -> str:
94+
def escape(cls, obj, to_json=False, ver=2) -> str:
95+
if ver == 1:
96+
return cls._escape_v1(obj, to_json)
97+
elif ver == 2:
98+
return cls._escape_v2(obj, to_json)
99+
return cls._escape_v1(obj, to_json)
100+
101+
@classmethod
102+
def _escape_v1(cls, obj, to_json=False):
93103
"""
94-
Works like aiomysql.connection.Connection.escape
104+
Original DIY `escape` method
95105
"""
96-
if value is None:
106+
if obj is None:
97107
return "NULL"
98-
elif isinstance(value, (Cases, RawSQL)):
99-
return value.sql
100-
elif isinstance(value, (int, float, bool)):
101-
return f"{value}"
102-
elif isinstance(value, (dict, list, tuple)):
103-
dumped = dumps(value, ensure_ascii=False)
108+
elif isinstance(obj, (Cases, RawSQL)):
109+
return obj.sql
110+
elif isinstance(obj, (int, float, bool)):
111+
return f"{obj}"
112+
elif isinstance(obj, (dict, list, tuple)):
113+
dumped = dumps(obj, ensure_ascii=False)
104114
if to_json:
105115
return f"CAST('{dumped}' AS JSON)"
106116
# Same with above line
107117
# return f"JSON_EXTRACT('{dumped}', '$')"
108118
return f"'{dumped}'"
109119
else:
110-
return f"'{value}'"
120+
return f"'{obj}'"
121+
122+
@classmethod
123+
def _escape_v2(cls, obj, to_json=False):
124+
"""
125+
Escape whatever value you pass to it.
126+
Partially copied from aiomysql.connection.Connection.escape
127+
"""
128+
if isinstance(obj, (Cases, RawSQL)):
129+
return obj.sql
130+
elif isinstance(obj, Enum):
131+
return cls._escape_v2(obj.value)
132+
elif isinstance(obj, (dict, list, tuple)):
133+
dumped = dumps(obj, ensure_ascii=False)
134+
if to_json:
135+
return f"CAST('{dumped}' AS JSON)"
136+
return f"'{dumped}'"
137+
else:
138+
return escape_item(obj, "utf8mb4")
111139

112140
@classmethod
113141
def select_custom_fields(
@@ -178,17 +206,17 @@ def update_json_field(
178206
json_obj = f"JSON_REMOVE({json_obj}, {rps})"
179207
if path_value_dict:
180208
pvs = [
181-
f"'{path}',{cls.sqlize_value(value, to_json=True)}"
209+
f"'{path}',{cls.escape(value, to_json=True)}"
182210
for (path, value) in path_value_dict.items()
183211
]
184212
json_obj = f"JSON_SET({json_obj}, {', '.join(pvs)})"
185213
if merge_dict:
186-
json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.sqlize_value(merge_dict)})"
214+
json_obj = f"JSON_MERGE_PATCH({json_obj}, {cls.escape(merge_dict)})"
187215

188216
assign_field_dict = assign_field_dict or {}
189217
assign_fields = []
190218
for k, v in assign_field_dict.items():
191-
assign_fields.append(f"{k}={cls.sqlize_value(v)}")
219+
assign_fields.append(f"{k}={cls.escape(v)}")
192220
assign_field = ", ".join(assign_fields) if assign_fields else None
193221

194222
sql = """
@@ -220,7 +248,7 @@ def upsert_on_duplicate(
220248
raise WrongParamsError("Parameters `table`, `dicts`, `insert_fields` are required")
221249

222250
values = [
223-
f" ({', '.join(cls.sqlize_value(d.get(f)) for f in insert_fields)})"
251+
f" ({', '.join(cls.escape(d.get(f)) for f in insert_fields)})"
224252
for d in dicts
225253
]
226254
# NOTE Beginning with MySQL 8.0.19, it is possible to use an alias for the row
@@ -279,7 +307,7 @@ def insert_into_select(
279307
assign_fields = []
280308
for k, v in assign_field_dict.items():
281309
fields.append(k)
282-
assign_fields.append(f"{cls.sqlize_value(v)} {k}")
310+
assign_fields.append(f"{cls.escape(v)} {k}")
283311

284312
sql = f"""
285313
INSERT INTO {wrap_backticks(to_table or table)}
@@ -304,14 +332,14 @@ def build_fly_table(
304332

305333
if using_values:
306334
rows = [
307-
f" ROW({', '.join(cls.sqlize_value(d.get(f)) for f in fields)})"
335+
f" ROW({', '.join(cls.escape(d.get(f)) for f in fields)})"
308336
for d in dicts
309337
]
310338
values = "VALUES\n" + ",\n".join(rows)
311339
table = f"fly_table ({', '.join(fields)})"
312340
else:
313341
rows = [
314-
f"SELECT {', '.join(f'{cls.sqlize_value(d.get(f))} {f}' for f in fields)}"
342+
f"SELECT {', '.join(f'{cls.escape(d.get(f))} {f}' for f in fields)}"
315343
for d in dicts
316344
]
317345
values = "\n UNION\n ".join(rows)
@@ -354,3 +382,6 @@ def bulk_update_from_dicts(
354382
"""
355383
logger.debug(sql)
356384
return sql
385+
386+
387+
SQLizer.sqlize_value = SQLizer.escape

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "fastapi-efficient-sql"
3-
version = "0.0.14"
3+
version = "0.0.15"
44
description = "Generate bulk DML SQL and execute them based on Tortoise ORM and mysql8.0+, and integrated with FastAPI."
55
authors = ["BryanLee <bryanlee@126.com>"]
66
keywords = ["sql", "fastapi", "tortoise-orm", "mysql8", "bulk-operation"]

tests/test_sqlizer.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,53 @@ def test_resolve_orders(self):
6868
orders = SQLizer.resolve_orders(["-created_at", "name"])
6969
assert orders == "created_at DESC, name ASC"
7070

71-
def test_sqlize_value(self):
72-
assert SQLizer.sqlize_value(None) == "NULL"
71+
def test_sqlize_value_v1(self):
72+
assert SQLizer.sqlize_value(None, ver=1) == "NULL"
7373

7474
raw_sql = RawSQL("statement")
75-
assert SQLizer.sqlize_value(raw_sql) == raw_sql.sql
75+
assert SQLizer.sqlize_value(raw_sql, ver=1) == raw_sql.sql
7676
cases = Cases("is_ok", {0: "No", 1: "Yes"})
77-
assert SQLizer.sqlize_value(cases) == cases.sql
77+
assert SQLizer.sqlize_value(cases, ver=1) == cases.sql
7878

79-
assert SQLizer.sqlize_value(1024) == "1024"
80-
assert SQLizer.sqlize_value(0.125) == "0.125"
81-
assert SQLizer.sqlize_value(True) == "True"
79+
assert SQLizer.sqlize_value(GenderEnum.unknown, ver=1) == "0"
80+
assert SQLizer.sqlize_value(LocaleEnum.zh_CN, ver=1) == "'zh_CN'"
81+
82+
assert SQLizer.sqlize_value(1024, ver=1) == "1024"
83+
assert SQLizer.sqlize_value(0.125, ver=1) == "0.125"
84+
assert SQLizer.sqlize_value(True, ver=1) == "True"
85+
86+
assert (
87+
SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=1)
88+
== """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)"""
89+
)
90+
assert SQLizer.sqlize_value([1, 2, 4], ver=1) == "'[1, 2, 4]'"
91+
assert SQLizer.sqlize_value(("a", "b", "c"), ver=1) == """'["a", "b", "c"]'"""
92+
93+
assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30), ver=1) == "'2023-01-01 12:30:00'"
94+
95+
def test_escape_v2(self):
96+
assert SQLizer.escape(None, ver=2) == "NULL"
97+
98+
raw_sql = RawSQL("statement")
99+
assert SQLizer.escape(raw_sql, ver=2) == raw_sql.sql
100+
cases = Cases("is_ok", {0: "No", 1: "Yes"})
101+
assert SQLizer.escape(cases, ver=2) == cases.sql
102+
103+
assert SQLizer.escape(GenderEnum.unknown, ver=2) == "0"
104+
assert SQLizer.escape(LocaleEnum.zh_CN, ver=2) == "'zh_CN'"
105+
106+
assert SQLizer.escape(1024, ver=2) == "1024"
107+
assert SQLizer.escape(0.125, ver=2) == "0.125"
108+
assert SQLizer.escape(True, ver=2) == "1"
82109

83110
assert (
84-
SQLizer.sqlize_value({"gender": 0, "name": "羊淑兰"}, to_json=True)
111+
SQLizer.escape({"gender": 0, "name": "羊淑兰"}, to_json=True, ver=2)
85112
== """CAST('{"gender": 0, "name": "羊淑兰"}' AS JSON)"""
86113
)
87-
assert SQLizer.sqlize_value([1, 2, 4]) == "'[1, 2, 4]'"
88-
assert SQLizer.sqlize_value(("a", "b", "c")) == """'["a", "b", "c"]'"""
114+
assert SQLizer.escape([1, 2, 4], ver=2) == "'[1, 2, 4]'"
115+
assert SQLizer.escape(("a", "b", "c"), ver=2) == """'["a", "b", "c"]'"""
89116

90-
assert SQLizer.sqlize_value(datetime(2023, 1, 1, 12, 30)) == "'2023-01-01 12:30:00'"
117+
assert SQLizer.escape(datetime(2023, 1, 1, 12, 30), ver=2) == "'2023-01-01T12:30:00'"
91118

92119
def test_select_custom_fields(self):
93120
with self.assertRaises(WrongParamsError):
@@ -227,7 +254,7 @@ def test_update_json_field(self):
227254
assert sql == """
228255
UPDATE `account` SET extend =
229256
JSON_MERGE_PATCH(JSON_SET(JSON_REMOVE(COALESCE(extend, '{}'), '$.deprecated'), '$.last_login',CAST('{"ipv4": "209.182.101.161"}' AS JSON), '$.uuid','fd04f7f2-24fc-4a73-a1d7-b6e99a464c5f'), '{"updated_at": "2022-10-30 21:34:15", "info": {"online_sec": 636}}')
230-
, active=True, name='new_name'
257+
, active=1, name='new_name'
231258
WHERE `id`=8
232259
"""
233260

@@ -326,7 +353,7 @@ def test_insert_into_select(self):
326353
assert archive_sql == """
327354
INSERT INTO `account_bak`
328355
(gender, locale, active, name, extend)
329-
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
356+
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
330357
FROM `account`
331358
WHERE `id` IN (4,5,6)
332359
"""
@@ -346,7 +373,7 @@ def test_insert_into_select(self):
346373
assert copy_sql == """
347374
INSERT INTO `account`
348375
(gender, locale, active, name, extend)
349-
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, False active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
376+
SELECT gender, CASE id WHEN 3 THEN 'zh_CN' WHEN 4 THEN 'en_US' WHEN 5 THEN 'fr_FR' ELSE '' END locale, 0 active, CONCAT(LEFT(name, 26), ' [NEW]') name, '{}' extend
350377
FROM `account`
351378
WHERE `id` IN (4,5,6)
352379
"""
@@ -368,9 +395,9 @@ def test_build_fly_table(self):
368395
)
369396
assert old_sql == """
370397
SELECT * FROM (
371-
SELECT 7 id, False active, 1 gender
398+
SELECT 7 id, 0 active, 1 gender
372399
UNION
373-
SELECT 15 id, True active, 0 gender
400+
SELECT 15 id, 1 active, 0 gender
374401
) AS fly_table"""
375402

376403
new_sql = SQLizer.build_fly_table(
@@ -384,8 +411,8 @@ def test_build_fly_table(self):
384411
assert new_sql == """
385412
SELECT * FROM (
386413
VALUES
387-
ROW(7, False, 1),
388-
ROW(15, True, 0)
414+
ROW(7, 0, 1),
415+
ROW(15, 1, 0)
389416
) AS fly_table (id, active, gender)"""
390417

391418
def test_bulk_update_from_dicts(self):
@@ -412,8 +439,8 @@ def test_bulk_update_from_dicts(self):
412439
JOIN (
413440
SELECT * FROM (
414441
VALUES
415-
ROW(7, False, False, 1, '{"test": 1, "debug": 0}'),
416-
ROW(15, False, True, 0, '{"test": 1, "debug": 0}')
442+
ROW(7, 0, 0, 1, '{"test": 1, "debug": 0}'),
443+
ROW(15, 0, 1, 0, '{"test": 1, "debug": 0}')
417444
) AS fly_table (id, deleted, active, gender, extend)
418445
) tmp ON `account`.id=tmp.id AND `account`.deleted=tmp.deleted
419446
SET `account`.active=tmp.active, `account`.gender=tmp.gender, `account`.extend=JSON_MERGE_PATCH(COALESCE(`account`.extend, '{}'), tmp.extend)

0 commit comments

Comments
 (0)