Skip to content

Commit 623ffb3

Browse files
JustBeYouamotl
authored andcommitted
Add support for multi-line SQL and commands.
1 parent 8c3de13 commit 623ffb3

File tree

5 files changed

+158
-51
lines changed

5 files changed

+158
-51
lines changed

CHANGES.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Unreleased
66
==========
77

88
- Fix inconsistent spacing around printed runtime. Thank you, @hammerhead.
9+
- Add support for multi-line input of commands and SQL statements for both
10+
copy-pasting inside the crash shell and input pipes into crash.
911

1012
2023/02/16 0.29.0
1113
=================

crate/crash/command.py

+30-43
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from getpass import getpass
3434
from operator import itemgetter
3535

36+
import sqlparse
3637
import urllib3
3738
from platformdirs import user_config_dir, user_data_dir
3839
from urllib3.exceptions import LocationParseError
@@ -176,31 +177,6 @@ def inner_fn(self, *args):
176177
return inner_fn
177178

178179

179-
def _parse_statements(lines):
180-
"""Return a generator of statements
181-
182-
Args: A list of strings that can contain one or more statements.
183-
Statements are separated using ';' at the end of a line
184-
Everything after the last ';' will be treated as the last statement.
185-
186-
>>> list(_parse_statements(['select * from ', 't1;', 'select name']))
187-
['select * from\\nt1', 'select name']
188-
189-
>>> list(_parse_statements(['select * from t1;', ' ']))
190-
['select * from t1']
191-
"""
192-
lines = (l.strip() for l in lines if l)
193-
lines = (l for l in lines if l and not l.startswith('--'))
194-
parts = []
195-
for line in lines:
196-
parts.append(line.rstrip(';'))
197-
if line.endswith(';'):
198-
yield '\n'.join(parts)
199-
parts[:] = []
200-
if parts:
201-
yield '\n'.join(parts)
202-
203-
204180
class CrateShell:
205181

206182
def __init__(self,
@@ -274,19 +250,28 @@ def pprint(self, rows, cols):
274250
self.get_num_columns())
275251
self.output_writer.write(result)
276252

277-
def process_iterable(self, stdin):
278-
any_statement = False
279-
for statement in _parse_statements(stdin):
280-
self._exec_and_print(statement)
281-
any_statement = True
282-
return any_statement
253+
def process_iterable(self, iterable):
254+
self._process_lines([line for text in iterable for line in text.split('\n')])
283255

284256
def process(self, text):
285-
if text.startswith('\\'):
286-
self._try_exec_cmd(text.lstrip('\\'))
287-
else:
288-
for statement in _parse_statements([text]):
289-
self._exec_and_print(statement)
257+
self._process_lines(text.split('\n'))
258+
259+
def _process_lines(self, lines):
260+
sql_lines = []
261+
for line in lines:
262+
line = line.strip()
263+
if line.startswith('\\'):
264+
self._process_sql('\n'.join(sql_lines))
265+
self._try_exec_cmd(line.lstrip('\\'))
266+
sql_lines = []
267+
else:
268+
sql_lines.append(line)
269+
self._process_sql('\n'.join(sql_lines))
270+
271+
def _process_sql(self, text):
272+
sql = sqlparse.format(text, strip_comments=False)
273+
for statement in sqlparse.split(sql):
274+
self._exec_and_print(statement)
290275

291276
def exit(self):
292277
self.close()
@@ -498,14 +483,15 @@ def stmt_type(statement):
498483
return re.findall(r'[\w]+', statement)[0].upper()
499484

500485

501-
def get_stdin():
486+
def get_lines_from_stdin():
502487
"""
503-
Get data from stdin, if any
488+
Get data line by line from stdin, if any
504489
"""
505-
if not sys.stdin.isatty():
506-
for line in sys.stdin:
507-
yield line
508-
return
490+
if sys.stdin.isatty():
491+
return
492+
493+
for line in sys.stdin:
494+
yield line
509495

510496

511497
def host_and_port(host_or_port):
@@ -622,7 +608,8 @@ def save_and_exit():
622608
cmd.process(args.command)
623609
save_and_exit()
624610

625-
if cmd.process_iterable(get_stdin()):
611+
if not sys.stdin.isatty():
612+
cmd.process_iterable(get_lines_from_stdin())
626613
save_and_exit()
627614

628615
from .repl import loop

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
'platformdirs<4',
3333
'prompt-toolkit>=3.0,<4',
3434
'tabulate>=0.9,<0.10',
35+
'sqlparse>=0.4.4,<0.5.0'
3536
]
3637

3738

tests/test_commands.py

+122-5
Original file line numberDiff line numberDiff line change
@@ -229,16 +229,24 @@ def test_sql_comments(self):
229229
-- Another SELECT statement.
230230
SELECT 2;
231231
-- Yet another SELECT statement with an inline comment.
232-
-- Other than the regular comments, it gets passed through to the database server.
232+
-- Comments get passed through to the database server.
233233
SELECT /* this is a comment */ 3;
234+
SELECT /* this is a multi-line
235+
comment */ 4;
234236
"""
235237
cmd = CrateShell()
236238
cmd._exec_and_print = MagicMock()
237239
cmd.process_iterable(sql.splitlines())
238-
cmd._exec_and_print.assert_has_calls([
239-
call("SELECT 1"),
240-
call("SELECT 2"),
241-
call("SELECT /* this is a comment */ 3"),
240+
self.assertListEqual(cmd._exec_and_print.mock_calls, [
241+
call("-- Just a dummy SELECT statement.\nSELECT 1;"),
242+
call("-- Another SELECT statement.\nSELECT 2;"),
243+
call('\n'.join([
244+
"-- Yet another SELECT statement with an inline comment.",
245+
"-- Comments get passed through to the database server.",
246+
"SELECT /* this is a comment */ 3;"
247+
])
248+
),
249+
call('SELECT /* this is a multi-line\ncomment */ 4;')
242250
])
243251

244252
def test_js_comments(self):
@@ -262,3 +270,112 @@ def test_js_comments(self):
262270
cmd.process(sql)
263271
self.assertEqual(1, cmd._exec_and_print.call_count)
264272
self.assertIn("CREATE FUNCTION", cmd._exec_and_print.mock_calls[0].args[0])
273+
274+
275+
class MultipleStatementsTest(TestCase):
276+
277+
def test_single_line_multiple_sql_statements(self):
278+
cmd = CrateShell()
279+
cmd._exec_and_print = MagicMock()
280+
cmd.process("SELECT 1; SELECT 2; SELECT 3;")
281+
self.assertListEqual(cmd._exec_and_print.mock_calls, [
282+
call("SELECT 1;"),
283+
call("SELECT 2;"),
284+
call("SELECT 3;"),
285+
])
286+
287+
def test_multiple_lines_multiple_sql_statements(self):
288+
cmd = CrateShell()
289+
cmd._exec_and_print = MagicMock()
290+
cmd.process("SELECT 1;\nSELECT 2; SELECT 3;\nSELECT\n4;")
291+
self.assertListEqual(cmd._exec_and_print.mock_calls, [
292+
call("SELECT 1;"),
293+
call("SELECT 2;"),
294+
call("SELECT 3;"),
295+
call("SELECT\n4;"),
296+
])
297+
298+
def test_single_sql_statement_multiple_lines(self):
299+
"""When processing single SQL statements, new lines are preserved."""
300+
301+
cmd = CrateShell()
302+
cmd._exec_and_print = MagicMock()
303+
cmd.process("\nSELECT\n1\nWHERE\n2\n=\n3\n;\n")
304+
self.assertListEqual(cmd._exec_and_print.mock_calls, [
305+
call("SELECT\n1\nWHERE\n2\n=\n3\n;"),
306+
])
307+
308+
def test_multiple_commands_no_sql(self):
309+
cmd = CrateShell()
310+
cmd._try_exec_cmd = MagicMock()
311+
cmd._exec_and_print = MagicMock()
312+
cmd.process("\\?\n\\connect 127.0.0.1")
313+
cmd._try_exec_cmd.assert_has_calls([
314+
call("?"),
315+
call("connect 127.0.0.1")
316+
])
317+
cmd._exec_and_print.assert_not_called()
318+
319+
def test_commands_and_multiple_sql_statements_interleaved(self):
320+
"""Combine all test cases above to be sure everything integrates well."""
321+
322+
cmd = CrateShell()
323+
mock_manager = MagicMock()
324+
325+
cmd._try_exec_cmd = mock_manager.cmd
326+
cmd._exec_and_print = mock_manager.sql
327+
328+
cmd.process("""
329+
\\?
330+
SELECT 1
331+
WHERE 2 = 3; SELECT 4;
332+
\\connect 127.0.0.1
333+
SELECT 5
334+
WHERE 6 = 7;
335+
\\check
336+
""")
337+
338+
self.assertListEqual(mock_manager.mock_calls, [
339+
call.cmd("?"),
340+
call.sql('SELECT 1\nWHERE 2 = 3;'),
341+
call.sql('SELECT 4;'),
342+
call.cmd("connect 127.0.0.1"),
343+
call.sql('SELECT 5\nWHERE 6 = 7;'),
344+
call.cmd("check"),
345+
])
346+
347+
def test_comments_along_multiple_statements(self):
348+
"""Test multiple types of comments along multi-statement input."""
349+
350+
cmd = CrateShell()
351+
cmd._exec_and_print = MagicMock()
352+
353+
cmd.process("""
354+
-- Multiple statements and comments on same line
355+
356+
SELECT /* inner comment */ 1; /* this is a single-line comment */ SELECT /* inner comment */ 2;
357+
358+
-- Multiple statements on multiple lines with multi-line comments between them
359+
360+
SELECT /* inner comment */ 3; /* this is a
361+
multi-line comment */ SELECT /* inner comment */ 4;
362+
363+
-- Multiple statements on multiple lines with multi-line comments between and inside them
364+
365+
SELECT /* inner multi-line
366+
comment */ 5 /* this is a multi-line
367+
comment before statement end */; /* this is another multi-line
368+
comment */ SELECT /* inner multi-line
369+
comment */ 6;
370+
""")
371+
372+
self.assertListEqual(cmd._exec_and_print.mock_calls, [
373+
call('-- Multiple statements and comments on same line\n\nSELECT /* inner comment */ 1;'),
374+
call('/* this is a single-line comment */ SELECT /* inner comment */ 2;'),
375+
376+
call('-- Multiple statements on multiple lines with multi-line comments between them\n\nSELECT /* inner comment */ 3;'),
377+
call('/* this is a\nmulti-line comment */ SELECT /* inner comment */ 4;'),
378+
379+
call('-- Multiple statements on multiple lines with multi-line comments between and inside them\n\nSELECT /* inner multi-line\ncomment */ 5 /* this is a multi-line\ncomment before statement end */;'),
380+
call('/* this is another multi-line\ncomment */ SELECT /* inner multi-line\ncomment */ 6;')
381+
])

tests/test_integration.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from crate.crash.command import (
1414
CrateShell,
1515
_create_shell,
16+
get_lines_from_stdin,
1617
get_parser,
17-
get_stdin,
1818
host_and_port,
1919
main,
2020
noargs_command,
@@ -315,7 +315,7 @@ def test_multiline_stdin(self):
315315
316316
Newlines must be replaced with whitespaces
317317
"""
318-
stmt = ''.join(list(get_stdin())).replace('\n', ' ')
318+
stmt = ''.join(list(get_lines_from_stdin())).replace('\n', ' ')
319319
expected = ("create table test( d string ) "
320320
"clustered into 2 shards "
321321
"with (number_of_replicas=0)")
@@ -334,7 +334,7 @@ def test_multiline_stdin_delimiter(self):
334334
335335
Newlines must be replaced with whitespaces
336336
"""
337-
stmt = ''.join(list(get_stdin())).replace('\n', ' ')
337+
stmt = ''.join(list(get_lines_from_stdin())).replace('\n', ' ')
338338
expected = ("create table test( d string ) "
339339
"clustered into 2 shards "
340340
"with (number_of_replicas=0);")

0 commit comments

Comments
 (0)