Skip to content

Commit 3644a97

Browse files
authored
feat: add HTTP and HTTPS to hive (#385)
* feat: add https protocol * support HTTP
1 parent d6e7140 commit 3644a97

File tree

3 files changed

+104
-3
lines changed

3 files changed

+104
-3
lines changed

pyhive/hive.py

+76-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
from __future__ import absolute_import
99
from __future__ import unicode_literals
1010

11+
import base64
1112
import datetime
1213
import re
1314
from decimal import Decimal
15+
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
16+
1417

1518
from TCLIService import TCLIService
1619
from TCLIService import constants
@@ -25,6 +28,7 @@
2528
import getpass
2629
import logging
2730
import sys
31+
import thrift.transport.THttpClient
2832
import thrift.protocol.TBinaryProtocol
2933
import thrift.transport.TSocket
3034
import thrift.transport.TTransport
@@ -38,6 +42,12 @@
3842

3943
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
4044

45+
# Translates the textual ``ssl_cert`` connection option into the matching
# ``ssl`` module verification-mode constant.
ssl_cert_parameter_map = dict(
    none=CERT_NONE,
    optional=CERT_OPTIONAL,
    required=CERT_REQUIRED,
)
50+
4151

4252
def _parse_timestamp(value):
4353
if value:
@@ -97,9 +107,21 @@ def connect(*args, **kwargs):
97107
class Connection(object):
98108
"""Wraps a Thrift session"""
99109

100-
def __init__(self, host=None, port=None, username=None, database='default', auth=None,
101-
configuration=None, kerberos_service_name=None, password=None,
102-
thrift_transport=None):
110+
def __init__(
111+
self,
112+
host=None,
113+
port=None,
114+
scheme=None,
115+
username=None,
116+
database='default',
117+
auth=None,
118+
configuration=None,
119+
kerberos_service_name=None,
120+
password=None,
121+
check_hostname=None,
122+
ssl_cert=None,
123+
thrift_transport=None
124+
):
103125
"""Connect to HiveServer2
104126
105127
:param host: What host HiveServer2 runs on
@@ -116,6 +138,32 @@ def __init__(self, host=None, port=None, username=None, database='default', auth
116138
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
117139
/impala/_thrift_api.py#L152-L160
118140
"""
141+
if scheme in ("https", "http") and thrift_transport is None:
142+
ssl_context = None
143+
if scheme == "https":
144+
ssl_context = create_default_context()
145+
ssl_context.check_hostname = check_hostname == "true"
146+
ssl_cert = ssl_cert or "none"
147+
ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE)
148+
thrift_transport = thrift.transport.THttpClient.THttpClient(
149+
uri_or_host=f"{scheme}://{host}:{port}/cliservice/",
150+
ssl_context=ssl_context,
151+
)
152+
153+
if auth in ("BASIC", "NOSASL", "NONE", None):
154+
# Always needs the Authorization header
155+
self._set_authorization_header(thrift_transport, username, password)
156+
elif auth == "KERBEROS" and kerberos_service_name:
157+
self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
158+
else:
159+
raise ValueError(
160+
"Authentication is not valid use one of:"
161+
"BASIC, NOSASL, KERBEROS, NONE"
162+
)
163+
host, port, auth, kerberos_service_name, password = (
164+
None, None, None, None, None
165+
)
166+
119167
username = username or getpass.getuser()
120168
configuration = configuration or {}
121169

@@ -207,6 +255,31 @@ def sasl_factory():
207255
self._transport.close()
208256
raise
209257

258+
@staticmethod
def _set_authorization_header(transport, username=None, password=None):
    """Install an HTTP Basic ``Authorization`` header on ``transport``.

    :param transport: object exposing ``setCustomHeaders(dict)``
    :param username: login name; any falsy value is replaced by ``"user"``
    :param password: secret; any falsy value is replaced by ``"pass"``
    """
    # Placeholder credentials are substituted so a header is always sent.
    if not username:
        username = "user"
    if not password:
        password = "pass"

    raw_pair = f"{username}:{password}".encode("UTF-8")
    token = base64.standard_b64encode(raw_pair).decode("UTF-8")
    transport.setCustomHeaders({"Authorization": f"Basic {token}"})
269+
270+
@staticmethod
def _set_kerberos_header(transport, kerberos_service_name, host) -> None:
    """Install a ``Negotiate`` ``Authorization`` header on ``transport``.

    The header value is the GSSAPI client response obtained from the
    ``kerberos`` module for the service ``<kerberos_service_name>@<host>``.

    :param transport: object exposing ``setCustomHeaders(dict)``
    :param kerberos_service_name: service part of the GSSAPI service name
    :param host: host part of the GSSAPI service name
    """
    # Imported inside the method, so the module is only loaded when this
    # method actually runs.
    import kerberos

    service_name = f"{kerberos_service_name}@{host}"
    _status, gss_context = kerberos.authGSSClientInit(service=service_name)

    # NOTE(review): the clean call precedes the step call, exactly as in the
    # original code; confirm this order (and the second argument) against
    # the kerberos binding in use.
    kerberos.authGSSClientClean(gss_context, "")
    kerberos.authGSSClientStep(gss_context, "")
    negotiate_token = kerberos.authGSSClientResponse(gss_context)

    header_map = {"Authorization": f"Negotiate {negotiate_token}"}
    transport.setCustomHeaders(header_map)
282+
210283
def __enter__(self):
211284
"""Transport should already be opened by __init__"""
212285
return self

pyhive/sqlalchemy_hive.py

+26
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,29 @@ def _check_unicode_returns(self, connection, additional_tests=None):
374374
def _check_unicode_description(self, connection):
375375
# We decode everything as UTF-8
376376
return True
377+
378+
379+
class HiveHTTPDialect(HiveDialect):
    """Hive dialect variant that passes ``scheme="http"`` to the connection."""

    name = "hive"
    scheme = "http"
    driver = "rest"

    def create_connect_args(self, url):
        """Build the connect arguments from a SQLAlchemy URL.

        :param url: URL object providing ``host``, ``port``, ``username``,
            ``password`` and ``query``
        :returns: ``([], kwargs)`` -- no positional arguments, plus keyword
            arguments ``host``, ``port`` (10000 when the URL has none),
            ``scheme``, ``username`` and ``password``
        """
        kwargs = {
            "host": url.host,
            "port": url.port or 10000,
            "scheme": self.scheme,
            "username": url.username or None,
            "password": url.password or None,
        }
        # Query-string options are merged last, so they can add to or
        # override the values derived from the URL itself.
        if url.query:
            kwargs.update(url.query)
        return [], kwargs
397+
398+
399+
class HiveHTTPSDialect(HiveHTTPDialect):
    """Same as the HTTP dialect, except that the scheme is ``"https"``."""

    name = "hive"
    scheme = "https"

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def run_tests(self):
6666
entry_points={
6767
'sqlalchemy.dialects': [
6868
'hive = pyhive.sqlalchemy_hive:HiveDialect',
69+
"hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect",
70+
"hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect",
6971
'presto = pyhive.sqlalchemy_presto:PrestoDialect',
7072
'trino = pyhive.sqlalchemy_trino:TrinoDialect',
7173
],

0 commit comments

Comments
 (0)