Skip to content

Commit 8eb0aea

Browse files
authored
Deal with datetime objects (#327)
* Deal with datetime objects * Allow changing milli/microseconds precision * Reference self instead of reinstantiating ParamEscaper
1 parent 3ac95b5 commit 8eb0aea

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

pyhive/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import abc
1515
import collections
1616
import time
17+
import datetime
1718
from future.utils import with_metaclass
1819
from itertools import islice
1920

@@ -201,6 +202,10 @@ def __cmp__(self, other):
201202

202203

203204
class ParamEscaper(object):
205+
_DATE_FORMAT = "%Y-%m-%d"
206+
_TIME_FORMAT = "%H:%M:%S.%f"
207+
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
208+
204209
def escape_args(self, parameters):
205210
if isinstance(parameters, dict):
206211
return {k: self.escape_item(v) for k, v in parameters.items()}
@@ -228,6 +233,11 @@ def escape_sequence(self, item):
228233
l = map(str, map(self.escape_item, item))
229234
return '(' + ','.join(l) + ')'
230235

236+
def escape_datetime(self, item, format, cutoff=0):
237+
dt_str = item.strftime(format)
238+
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
239+
return "'{}'".format(formatted)
240+
231241
def escape_item(self, item):
232242
if item is None:
233243
return 'NULL'
@@ -237,6 +247,10 @@ def escape_item(self, item):
237247
return self.escape_string(item)
238248
elif isinstance(item, collections.Iterable):
239249
return self.escape_sequence(item)
250+
elif isinstance(item, datetime.datetime):
251+
return self.escape_datetime(item, self._DATETIME_FORMAT)
252+
elif isinstance(item, datetime.date):
253+
return self.escape_datetime(item, self._DATE_FORMAT)
240254
else:
241255
raise exc.ProgrammingError("Unsupported object {}".format(item))
242256

pyhive/presto.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pyhive.exc import * # noqa
1616
import base64
1717
import getpass
18+
import datetime
1819
import logging
1920
import requests
2021
from requests.auth import HTTPBasicAuth
@@ -32,7 +33,16 @@
3233
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
3334

3435
_logger = logging.getLogger(__name__)
35-
_escaper = common.ParamEscaper()
36+
37+
38+
class PrestoParamEscaper(common.ParamEscaper):
39+
def escape_datetime(self, item, format):
40+
_type = "timestamp" if isinstance(item, datetime.datetime) else "date"
41+
formatted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3)
42+
return "{} {}".format(_type, formatted)
43+
44+
45+
_escaper = PrestoParamEscaper()
3646

3747

3848
def connect(*args, **kwargs):

pyhive/tests/test_common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import absolute_import
33
from __future__ import unicode_literals
44
from pyhive import common
5-
5+
import datetime
66
import unittest
77

88

@@ -34,3 +34,7 @@ def test_escape_args(self):
3434
("('a','b','c')",))
3535
self.assertEqual(escaper.escape_args((['你好', 'b', 'c'],)),
3636
("('你好','b','c')",))
37+
self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
38+
("'2020-04-17'",))
39+
self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
40+
("'2020-04-17 12:00:00.123456'",))

pyhive/tests/test_presto.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pyhive.tests.dbapi_test_case import with_cursor
1818
import mock
1919
import unittest
20+
import datetime
2021

2122
_HOST = 'localhost'
2223
_PORT = '8080'
@@ -32,6 +33,14 @@ def test_bad_protocol(self):
3233
self.assertRaisesRegexp(ValueError, 'Protocol must be',
3334
lambda: presto.connect('localhost', protocol='nonsense').cursor())
3435

36+
def test_escape_args(self):
37+
escaper = presto.PrestoParamEscaper()
38+
39+
self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
40+
("date '2020-04-17'",))
41+
self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
42+
("timestamp '2020-04-17 12:00:00.123'",))
43+
3544
@with_cursor
3645
def test_description(self, cursor):
3746
cursor.execute('SELECT 1 AS foobar FROM one_row')

0 commit comments

Comments
 (0)