8
8
from __future__ import absolute_import
9
9
from __future__ import unicode_literals
10
10
11
+ import base64
11
12
import datetime
12
13
import re
13
14
from decimal import Decimal
15
+ from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED , create_default_context
16
+
14
17
15
18
from TCLIService import TCLIService
16
19
from TCLIService import constants
25
28
import getpass
26
29
import logging
27
30
import sys
31
+ import thrift .transport .THttpClient
28
32
import thrift .protocol .TBinaryProtocol
29
33
import thrift .transport .TSocket
30
34
import thrift .transport .TTransport
38
42
39
43
_TIMESTAMP_PATTERN = re .compile (r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)' )
40
44
45
+ ssl_cert_parameter_map = {
46
+ "none" : CERT_NONE ,
47
+ "optional" : CERT_OPTIONAL ,
48
+ "required" : CERT_REQUIRED ,
49
+ }
50
+
41
51
42
52
def _parse_timestamp (value ):
43
53
if value :
@@ -97,9 +107,21 @@ def connect(*args, **kwargs):
97
107
class Connection (object ):
98
108
"""Wraps a Thrift session"""
99
109
100
- def __init__ (self , host = None , port = None , username = None , database = 'default' , auth = None ,
101
- configuration = None , kerberos_service_name = None , password = None ,
102
- thrift_transport = None ):
110
+ def __init__ (
111
+ self ,
112
+ host = None ,
113
+ port = None ,
114
+ scheme = None ,
115
+ username = None ,
116
+ database = 'default' ,
117
+ auth = None ,
118
+ configuration = None ,
119
+ kerberos_service_name = None ,
120
+ password = None ,
121
+ check_hostname = None ,
122
+ ssl_cert = None ,
123
+ thrift_transport = None
124
+ ):
103
125
"""Connect to HiveServer2
104
126
105
127
:param host: What host HiveServer2 runs on
@@ -116,6 +138,32 @@ def __init__(self, host=None, port=None, username=None, database='default', auth
116
138
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
117
139
/impala/_thrift_api.py#L152-L160
118
140
"""
141
+ if scheme in ("https" , "http" ) and thrift_transport is None :
142
+ ssl_context = None
143
+ if scheme == "https" :
144
+ ssl_context = create_default_context ()
145
+ ssl_context .check_hostname = check_hostname == "true"
146
+ ssl_cert = ssl_cert or "none"
147
+ ssl_context .verify_mode = ssl_cert_parameter_map .get (ssl_cert , CERT_NONE )
148
+ thrift_transport = thrift .transport .THttpClient .THttpClient (
149
+ uri_or_host = f"{ scheme } ://{ host } :{ port } /cliservice/" ,
150
+ ssl_context = ssl_context ,
151
+ )
152
+
153
+ if auth in ("BASIC" , "NOSASL" , "NONE" , None ):
154
+ # Always needs the Authorization header
155
+ self ._set_authorization_header (thrift_transport , username , password )
156
+ elif auth == "KERBEROS" and kerberos_service_name :
157
+ self ._set_kerberos_header (thrift_transport , kerberos_service_name , host )
158
+ else :
159
+ raise ValueError (
160
+ "Authentication is not valid use one of:"
161
+ "BASIC, NOSASL, KERBEROS, NONE"
162
+ )
163
+ host , port , auth , kerberos_service_name , password = (
164
+ None , None , None , None , None
165
+ )
166
+
119
167
username = username or getpass .getuser ()
120
168
configuration = configuration or {}
121
169
@@ -207,6 +255,31 @@ def sasl_factory():
207
255
self ._transport .close ()
208
256
raise
209
257
258
+ @staticmethod
259
+ def _set_authorization_header (transport , username = None , password = None ):
260
+ username = username or "user"
261
+ password = password or "pass"
262
+ auth_credentials = f"{ username } :{ password } " .encode ("UTF-8" )
263
+ auth_credentials_base64 = base64 .standard_b64encode (auth_credentials ).decode (
264
+ "UTF-8"
265
+ )
266
+ transport .setCustomHeaders (
267
+ {"Authorization" : f"Basic { auth_credentials_base64 } " }
268
+ )
269
+
270
+ @staticmethod
271
+ def _set_kerberos_header (transport , kerberos_service_name , host ) -> None :
272
+ import kerberos
273
+
274
+ __ , krb_context = kerberos .authGSSClientInit (
275
+ service = f"{ kerberos_service_name } @{ host } "
276
+ )
277
+ kerberos .authGSSClientClean (krb_context , "" )
278
+ kerberos .authGSSClientStep (krb_context , "" )
279
+ auth_header = kerberos .authGSSClientResponse (krb_context )
280
+
281
+ transport .setCustomHeaders ({"Authorization" : f"Negotiate { auth_header } " })
282
+
210
283
def __enter__ (self ):
211
284
"""Transport should already be opened by __init__"""
212
285
return self
0 commit comments