Skip to content

Commit e2a406f

Browse files
authored
Fix(duckdb): Use CREATE OR REPLACE when registering secrets on cursor init to prevent an 'already exists' error (#4974)
1 parent ed8b404 commit e2a406f

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

sqlmesh/core/config/connection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ def init(cursor: duckdb.DuckDBPyConnection) -> None:
380380
if secret_settings:
381381
secret_clause = ", ".join(secret_settings)
382382
try:
383-
cursor.execute(f"CREATE SECRET {secret_name} ({secret_clause});")
383+
cursor.execute(
384+
f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
385+
)
384386
except Exception as e:
385387
raise ConfigError(f"Failed to create secret: {e}")
386388

tests/core/engine_adapter/integration/test_integration_duckdb.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from threading import current_thread, Thread
44
import random
55
from sqlglot import exp
6+
from pathlib import Path
7+
from concurrent.futures import ThreadPoolExecutor, as_completed
68

79
from sqlmesh.core.config.connection import DuckDBConnectionConfig
810
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
@@ -11,7 +13,7 @@
1113

1214

1315
@pytest.mark.parametrize("database", [None, "db.db"])
14-
def test_multithread_concurrency(tmp_path, database: t.Optional[str]):
16+
def test_multithread_concurrency(tmp_path: Path, database: t.Optional[str]):
1517
num_threads = 100
1618

1719
if database:
@@ -72,3 +74,35 @@ def read_from_thread():
7274

7375
tables = adapter.fetchall("show tables")
7476
assert len(tables) == num_threads + 1
77+
78+
79+
def test_secret_registration_from_multiple_connections(tmp_path: Path):
80+
database = str(tmp_path / "db.db")
81+
82+
config = DuckDBConnectionConfig(
83+
database=database,
84+
concurrent_tasks=2,
85+
secrets={"s3": {"type": "s3", "region": "us-east-1", "key_id": "foo", "secret": "bar"}},
86+
)
87+
88+
adapter = config.create_engine_adapter()
89+
pool = adapter._connection_pool
90+
91+
assert isinstance(pool, ThreadLocalSharedConnectionPool)
92+
93+
def _open_connection() -> bool:
94+
# this triggers cursor_init() to be run again for the new connection from the new thread
95+
# if the operations in cursor_init() are not idempotent, DuckDB will throw an error and this test will fail
96+
cur = pool.get_cursor()
97+
cur.execute("SELECT name FROM duckdb_secrets()")
98+
secret_names = [name for name_row in cur.fetchall() for name in name_row]
99+
assert secret_names == ["s3"]
100+
return True
101+
102+
thread_pool = ThreadPoolExecutor(max_workers=4)
103+
futures = []
104+
for _ in range(10):
105+
futures.append(thread_pool.submit(_open_connection))
106+
107+
for future in as_completed(futures):
108+
assert future.result()

tests/core/test_connection_config.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -489,21 +489,23 @@ def test_duckdb_multiple_secrets(mock_connect, make_config):
489489
cursor = config.create_engine_adapter().cursor
490490

491491
execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list]
492-
create_secret_calls = [call for call in execute_calls if call.startswith("CREATE SECRET")]
492+
create_secret_calls = [
493+
call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET")
494+
]
493495

494496
# Should have exactly 2 CREATE SECRET calls
495497
assert len(create_secret_calls) == 2
496498

497499
# Verify the SQL for the first secret (S3)
498500
assert (
499501
create_secret_calls[0]
500-
== "CREATE SECRET (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
502+
== "CREATE OR REPLACE SECRET (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
501503
)
502504

503505
# Verify the SQL for the second secret (Azure)
504506
assert (
505507
create_secret_calls[1]
506-
== "CREATE SECRET (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
508+
== "CREATE OR REPLACE SECRET (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
507509
)
508510

509511

@@ -541,21 +543,23 @@ def test_duckdb_named_secrets(mock_connect, make_config):
541543
cursor = config.create_engine_adapter().cursor
542544

543545
execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list]
544-
create_secret_calls = [call for call in execute_calls if call.startswith("CREATE SECRET")]
546+
create_secret_calls = [
547+
call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET")
548+
]
545549

546550
# Should have exactly 2 CREATE SECRET calls
547551
assert len(create_secret_calls) == 2
548552

549553
# Verify the SQL for the first secret (S3) includes the secret name
550554
assert (
551555
create_secret_calls[0]
552-
== "CREATE SECRET my_s3_secret (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
556+
== "CREATE OR REPLACE SECRET my_s3_secret (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
553557
)
554558

555559
# Verify the SQL for the second secret (Azure) includes the secret name
556560
assert (
557561
create_secret_calls[1]
558-
== "CREATE SECRET my_azure_secret (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
562+
== "CREATE OR REPLACE SECRET my_azure_secret (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
559563
)
560564

561565

0 commit comments

Comments
 (0)