Skip to content

Commit 72d3fd9

Browse files
committed
Auth: Add JWT authentication to client API and ctk shell
1 parent 5a5bba6 commit 72d3fd9

File tree

18 files changed

+211
-97
lines changed

18 files changed

+211
-97
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
deploy/start/resume and data import procedures using fluent API and CLI.
99
- CLI naming things: Rename `--cratedb-sqlalchemy-url` to `--sqlalchemy-url`
1010
and `--cratedb-http-url` to `--http-url`.
11+
- Auth: Added JWT authentication to client API and `ctk shell`.
1112

1213
## 2025/04/23 v0.0.32
1314
- MCP: Add subsystem providing a few server and client utilities through

cratedb_toolkit/cluster/core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import time
55
import typing as t
6+
from contextlib import nullcontext
67
from copy import deepcopy
78
from pathlib import Path
89

@@ -20,6 +21,7 @@
2021
OperationFailed,
2122
)
2223
from cratedb_toolkit.model import DatabaseAddress, InputOutputResource, TableAddress
24+
from cratedb_toolkit.util.client import jwt_token_patch
2325
from cratedb_toolkit.util.data import asbool
2426
from cratedb_toolkit.util.database import DatabaseAdapter
2527
from cratedb_toolkit.util.runtime import flexfun
@@ -131,6 +133,8 @@ def __init__(
131133
)
132134

133135
self.cm = CloudManager()
136+
self._jwt_ctx: t.ContextManager = nullcontext()
137+
self._client_bundle: t.Optional[ClientBundle] = None
134138

135139
def __enter__(self):
136140
"""Enter the context manager, ensuring the cluster is running."""
@@ -204,6 +208,7 @@ def probe(self) -> "ManagedCluster":
204208
self.cluster_id = self.info.cloud["id"]
205209
self.cluster_name = self.info.cloud["name"]
206210
self.address = DatabaseAddress.from_httpuri(self.info.cloud["url"])
211+
self._jwt_ctx = jwt_token_patch(self.info.jwt.token)
207212

208213
except (CroudException, DatabaseAddressMissingError) as ex:
209214
self.exists = False
@@ -415,8 +420,9 @@ def query(self, sql: str):
415420
# Ensure we have cluster connection details.
416421
if not self.info or not self.info.cloud.get("url"):
417422
self.probe()
418-
client_bundle = self.get_client_bundle()
419-
return client_bundle.adapter.run_sql(sql, records=True)
423+
with self._jwt_ctx:
424+
client_bundle = self.get_client_bundle()
425+
return client_bundle.adapter.run_sql(sql, records=True)
420426

421427

422428
@dataclasses.dataclass
@@ -429,6 +435,7 @@ class StandaloneCluster(ClusterBase):
429435
info: t.Optional[ClusterInformation] = None
430436
exists: bool = False
431437
_load_table_result: t.Optional[bool] = None
438+
_client_bundle: t.Optional[ClientBundle] = None
432439

433440
def __post_init__(self):
434441
super().__init__()

cratedb_toolkit/cluster/croud.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
import json
2+
import logging
23
import os
34
import typing as t
45
from pathlib import Path
56

67
from cratedb_toolkit.exception import CroudException
78
from cratedb_toolkit.model import InputOutputResource, TableAddress
8-
from cratedb_toolkit.util.croud import CroudCall, CroudWrapper
9+
from cratedb_toolkit.util.croud import CroudCall, CroudClient, CroudWrapper
910

1011
# Default to a stable version if not specified in the environment.
1112
# TODO: Use `latest` CrateDB by default, or even `nightly`?
1213
DEFAULT_CRATEDB_VERSION = "5.10.4"
1314

1415

16+
logger = logging.getLogger(__name__)
17+
18+
1519
class CloudManager:
1620
"""
1721
A wrapper around the CrateDB Cloud API through the `croud` package, providing common methods.
@@ -374,3 +378,16 @@ def create_import_job(self, resource: InputOutputResource, target: TableAddress)
374378

375379
wr = CroudWrapper(call=call)
376380
return wr.invoke()
381+
382+
def get_jwt_token(self) -> t.Dict[str, str]:
383+
"""
384+
Retrieve per-cluster JWT token.
385+
"""
386+
client = CroudClient.create()
387+
data, errors = client.get(f"/api/v2/clusters/{self.cluster_id}/jwt/")
388+
errmsg = "Getting JWT token failed: Unknown error"
389+
if errors:
390+
errmsg = f"Getting JWT token failed: {errors}"
391+
if data is None:
392+
raise IOError(errmsg)
393+
return data

cratedb_toolkit/cluster/model.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
logger = logging.getLogger(__name__)
1717

1818

19+
@dataclasses.dataclass
20+
class JwtResponse:
21+
expiry: str
22+
refresh: str
23+
token: str
24+
25+
def get_token(self):
26+
# TODO: Persist token across sessions.
27+
# TODO: Refresh automatically when expired.
28+
return self.token
29+
30+
1931
@dataclasses.dataclass
2032
class ClusterInformation:
2133
"""
@@ -76,6 +88,14 @@ def from_name(cls, cluster_name: str) -> "ClusterInformation":
7688
def asdict(self) -> t.Dict[str, t.Any]:
7789
return deepcopy(dataclasses.asdict(self))
7890

91+
@property
92+
def jwt(self) -> JwtResponse:
93+
"""
94+
Return per-cluster JWT token response.
95+
"""
96+
cc = CloudCluster(cluster_id=self.cloud_id)
97+
return JwtResponse(**cc.get_jwt_token())
98+
7999

80100
@dataclasses.dataclass
81101
class ClientBundle:

cratedb_toolkit/shell/cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
import click
24

35
from cratedb_toolkit import DatabaseCluster
@@ -13,6 +15,8 @@
1315
from cratedb_toolkit.util.cli import boot_click, docstring_format_verbatim
1416
from cratedb_toolkit.util.crash import get_crash_output_formats, run_crash
1517

18+
logger = logging.getLogger(__name__)
19+
1620
output_formats = get_crash_output_formats()
1721

1822

@@ -39,7 +43,7 @@ def help_cli():
3943
@option_username
4044
@option_password
4145
@option_schema
42-
@click.option("--command", type=str, required=False, help="SQL command")
46+
@click.option("--command", "-c", type=str, required=False, help="SQL command")
4347
@click.option(
4448
"--format",
4549
"format_",
@@ -82,10 +86,20 @@ def cli(
8286

8387
http_url = cluster.address.httpuri
8488

89+
is_cloud = cluster_id is not None or cluster_name is not None
90+
jwt_token = None
91+
if is_cloud:
92+
if username is not None:
93+
logger.info("Using username/password credentials for authentication")
94+
else:
95+
logger.info("Using JWT token for authentication")
96+
jwt_token = cluster.info.jwt.token
97+
8598
run_crash(
8699
hosts=http_url,
87100
username=username,
88101
password=password,
102+
jwt_token=jwt_token,
89103
schema=schema,
90104
command=command,
91105
output_format=format_,

cratedb_toolkit/util/client.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import contextlib
2+
from unittest import mock
3+
4+
import crate.client.http
5+
6+
7+
@contextlib.contextmanager
8+
def jwt_token_patch(jwt_token: str = None):
9+
with mock.patch.object(crate.client.http.Client, "_request", _mk_crate_client_server_request(jwt_token)):
10+
yield
11+
12+
13+
def _mk_crate_client_server_request(jwt_token: str = None):
14+
"""
15+
Create a monkey patched Server.request method to add the Authorization header for JWT token-based authentication.
16+
"""
17+
18+
_crate_client_server_request_dist = crate.client.http.Client._request
19+
20+
def _crate_client_server_request(self, *args, **kwargs):
21+
"""
22+
Monkey patch the Server.request method to add the Authorization header for JWT token-based authentication.
23+
24+
TODO: Submit to upstream libraries and programs `crate.client`, `crate.crash`, and `sqlalchemy-cratedb`.
25+
"""
26+
if jwt_token:
27+
kwargs.setdefault("headers", {})
28+
kwargs["headers"].update({"Authorization": "Bearer " + jwt_token})
29+
return _crate_client_server_request_dist(self, *args, **kwargs)
30+
31+
return _crate_client_server_request

cratedb_toolkit/util/crash.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,17 @@
66

77
from crate.crash.command import main
88

9+
from cratedb_toolkit.util.client import jwt_token_patch
10+
911

1012
def run_crash(
11-
hosts: str, command: str, output_format: str = None, schema: str = None, username: str = None, password: str = None
13+
hosts: str,
14+
command: str,
15+
output_format: str = None,
16+
schema: str = None,
17+
username: str = None,
18+
password: str = None,
19+
jwt_token: str = None,
1220
):
1321
"""
1422
Run the interactive CrateDB database shell using `crash`.
@@ -23,12 +31,11 @@ def run_crash(
2331
cmd += ["--command", command]
2432
if output_format:
2533
cmd += ["--format", output_format]
26-
with mock.patch.object(sys, "argv", cmd):
27-
password_context: contextlib.AbstractContextManager = contextlib.nullcontext()
28-
if password:
29-
password_context = mock.patch.dict(os.environ, {"CRATEPW": password})
30-
with password_context:
31-
main()
34+
password_context: contextlib.AbstractContextManager = contextlib.nullcontext()
35+
if password:
36+
password_context = mock.patch.dict(os.environ, {"CRATEPW": password})
37+
with mock.patch.object(sys, "argv", cmd), jwt_token_patch(jwt_token=jwt_token), password_context:
38+
main()
3239

3340

3441
def get_crash_output_formats() -> t.List[str]:

cratedb_toolkit/util/croud.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import croud.api
1616
import yaml
17+
from boltons.typeutils import classproperty
1718
from croud.config.configuration import Configuration
1819

1920
import cratedb_toolkit
@@ -151,24 +152,6 @@ def invoke_capturing(self, fun: t.Callable, *args: t.List[t.Any], **kwargs: t.Di
151152
buffer.seek(0)
152153
return buffer.read()
153154

154-
@property
155-
def headless_config(self) -> Configuration:
156-
cfg = Configuration("croud.yaml")
157-
if cfg._file_path.exists() and "CRATEDB_CLOUD_API_KEY" not in os.environ:
158-
return cfg
159-
160-
tmp_file = NamedTemporaryFile()
161-
tmp_path = Path(tmp_file.name)
162-
config = Configuration("headless.yaml", tmp_path)
163-
164-
# Get credentials from the environment.
165-
config.profile["key"] = os.environ.get("CRATEDB_CLOUD_API_KEY")
166-
config.profile["secret"] = os.environ.get("CRATEDB_CLOUD_API_SECRET")
167-
config.profile["organization-id"] = os.environ.get("CRATEDB_CLOUD_ORGANIZATION_ID")
168-
# config.profile["endpoint"] = os.environ.get("CRATEDB_CLOUD_ENDPOINT") # noqa: ERA001
169-
170-
return config
171-
172155
def run_croud_fun(self, fun: t.Callable, with_exceptions: bool = True):
173156
"""
174157
Wrapper function to call into `croud`, for catching and converging error messages.
@@ -200,14 +183,13 @@ def print_fun(levelname: str, *args, **kwargs):
200183
# https://stackoverflow.com/a/46481946
201184
levels = ["debug", "info", "warning", "error", "success"]
202185
with contextlib.ExitStack() as stack:
203-
# Patch all `print_*` functions.
186+
# Patch all `print_*` functions, as they would obstruct the output.
204187
for level in levels:
205188
p = patch(f"croud.printer.print_{level}", functools.partial(print_fun, level))
206189
stack.enter_context(p)
207190

208191
# Patch configuration.
209-
p = patch("croud.config._CONFIG", self.headless_config)
210-
stack.enter_context(p)
192+
stack.enter_context(headless_config())
211193

212194
# TODO: When aiming to disable wait-for-completion.
213195
"""
@@ -219,6 +201,15 @@ def print_fun(levelname: str, *args, **kwargs):
219201
return fun()
220202

221203

204+
@contextlib.contextmanager
205+
def headless_config():
206+
"""
207+
Patch the `croud.config` module to use a headless configuration.
208+
"""
209+
with patch("croud.config._CONFIG", CroudClient.get_headless_config):
210+
yield
211+
212+
222213
class CroudClient(croud.api.Client):
223214
"""
224215
A slightly modified `croud.api.Client` class, to inject a custom User-Agent header.
@@ -229,6 +220,42 @@ def __init__(self, *args, **kwargs):
229220
ua = f"{cratedb_toolkit.__appname__}/{cratedb_toolkit.__version__} Python/{python_version()}"
230221
self.session.headers["User-Agent"] = ua
231222

223+
@staticmethod
224+
def create() -> "croud.api.Client":
225+
"""
226+
Canonical factory method for creating a `croud.api.Client` instance.
227+
"""
228+
with headless_config():
229+
from croud.config import CONFIG
230+
231+
return croud.api.Client(
232+
CONFIG.endpoint,
233+
token=CONFIG.token,
234+
on_token=CONFIG.set_current_auth_token,
235+
key=CONFIG.key,
236+
secret=CONFIG.secret,
237+
region=CONFIG.region,
238+
sudo=False,
239+
)
240+
241+
@classproperty
242+
def get_headless_config(cls) -> Configuration:
243+
cfg = Configuration("croud.yaml")
244+
if cfg._file_path.exists() and "CRATEDB_CLOUD_API_KEY" not in os.environ:
245+
return cfg
246+
247+
tmp_file = NamedTemporaryFile()
248+
tmp_path = Path(tmp_file.name)
249+
config = Configuration("headless.yaml", tmp_path)
250+
251+
# Get credentials from the environment.
252+
config.profile["key"] = os.environ.get("CRATEDB_CLOUD_API_KEY")
253+
config.profile["secret"] = os.environ.get("CRATEDB_CLOUD_API_SECRET")
254+
config.profile["organization-id"] = os.environ.get("CRATEDB_CLOUD_ORGANIZATION_ID")
255+
# config.profile["endpoint"] = os.environ.get("CRATEDB_CLOUD_ENDPOINT") # noqa: ERA001
256+
257+
return config
258+
232259

233260
croud.api.Client = CroudClient
234261

cratedb_toolkit/util/database.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,17 @@ def run_sql_real(self, sql: str, parameters: t.Mapping[str, str] = None, records
120120
Invoke SQL statement, and return results.
121121
"""
122122
results = []
123-
with self.engine.connect() as connection:
124-
for statement in sqlparse.split(sql):
125-
if self.internal:
126-
statement += self.internal_tag
127-
result = connection.execute(sa.text(statement), parameters)
128-
data: t.Any
129-
if records:
130-
rows = result.mappings().fetchall()
131-
data = [dict(row.items()) for row in rows]
132-
else:
133-
data = result.fetchall()
134-
results.append(data)
123+
for statement in sqlparse.split(sql):
124+
if self.internal:
125+
statement += self.internal_tag
126+
result = self.connection.execute(sa.text(statement), parameters)
127+
data: t.Any
128+
if records:
129+
rows = result.mappings().fetchall()
130+
data = [dict(row.items()) for row in rows]
131+
else:
132+
data = result.fetchall()
133+
results.append(data)
135134

136135
# Backward-compatibility.
137136
if len(results) == 1:

doc/cluster/tutorial.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ CRATEDB_CLOUD_API_SECRET='<YOUR_API_SECRET_HERE>'
2121
# CrateDB Cloud cluster identifier (id or name).
2222
# CRATEDB_CLUSTER_ID='<YOUR_CLUSTER_ID_HERE>'
2323
CRATEDB_CLUSTER_NAME='<YOUR_CLUSTER_NAME_HERE>'
24-
25-
# Database credentials.
26-
CRATEDB_USERNAME='<YOUR_DATABASE_USERNAME_HERE>'
27-
CRATEDB_PASSWORD='<YOUR_DATABASE_PASSWORD_HERE>'
2824
EOF
2925
```
3026

0 commit comments

Comments
 (0)