Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def init(cursor: duckdb.DuckDBPyConnection) -> None:
if secret_settings:
secret_clause = ", ".join(secret_settings)
try:
cursor.execute(f"CREATE SECRET {secret_name} ({secret_clause});")
cursor.execute(
f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});"
)
except Exception as e:
raise ConfigError(f"Failed to create secret: {e}")

Expand Down
36 changes: 35 additions & 1 deletion tests/core/engine_adapter/integration/test_integration_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from threading import current_thread, Thread
import random
from sqlglot import exp
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from sqlmesh.core.config.connection import DuckDBConnectionConfig
from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
Expand All @@ -11,7 +13,7 @@


@pytest.mark.parametrize("database", [None, "db.db"])
def test_multithread_concurrency(tmp_path, database: t.Optional[str]):
def test_multithread_concurrency(tmp_path: Path, database: t.Optional[str]):
num_threads = 100

if database:
Expand Down Expand Up @@ -72,3 +74,35 @@ def read_from_thread():

tables = adapter.fetchall("show tables")
assert len(tables) == num_threads + 1


def test_secret_registration_from_multiple_connections(tmp_path: Path):
database = str(tmp_path / "db.db")

config = DuckDBConnectionConfig(
database=database,
concurrent_tasks=2,
secrets={"s3": {"type": "s3", "region": "us-east-1", "key_id": "foo", "secret": "bar"}},
)

adapter = config.create_engine_adapter()
pool = adapter._connection_pool

assert isinstance(pool, ThreadLocalSharedConnectionPool)

def _open_connection() -> bool:
# this triggers cursor_init() to be run again for the new connection from the new thread
# if the operations in cursor_init() are not idempotent, DuckDB will throw an error and this test will fail
cur = pool.get_cursor()
cur.execute("SELECT name FROM duckdb_secrets()")
secret_names = [name for name_row in cur.fetchall() for name in name_row]
assert secret_names == ["s3"]
return True

thread_pool = ThreadPoolExecutor(max_workers=4)
futures = []
for _ in range(10):
futures.append(thread_pool.submit(_open_connection))

for future in as_completed(futures):
assert future.result()
16 changes: 10 additions & 6 deletions tests/core/test_connection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,21 +489,23 @@ def test_duckdb_multiple_secrets(mock_connect, make_config):
cursor = config.create_engine_adapter().cursor

execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list]
create_secret_calls = [call for call in execute_calls if call.startswith("CREATE SECRET")]
create_secret_calls = [
call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET")
]

# Should have exactly 2 CREATE SECRET calls
assert len(create_secret_calls) == 2

# Verify the SQL for the first secret (S3)
assert (
create_secret_calls[0]
== "CREATE SECRET (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
== "CREATE OR REPLACE SECRET (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
)

# Verify the SQL for the second secret (Azure)
assert (
create_secret_calls[1]
== "CREATE SECRET (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
== "CREATE OR REPLACE SECRET (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
)


Expand Down Expand Up @@ -541,21 +543,23 @@ def test_duckdb_named_secrets(mock_connect, make_config):
cursor = config.create_engine_adapter().cursor

execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list]
create_secret_calls = [call for call in execute_calls if call.startswith("CREATE SECRET")]
create_secret_calls = [
call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET")
]

# Should have exactly 2 CREATE SECRET calls
assert len(create_secret_calls) == 2

# Verify the SQL for the first secret (S3) includes the secret name
assert (
create_secret_calls[0]
== "CREATE SECRET my_s3_secret (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
== "CREATE OR REPLACE SECRET my_s3_secret (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');"
)

# Verify the SQL for the second secret (Azure) includes the secret name
assert (
create_secret_calls[1]
== "CREATE SECRET my_azure_secret (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
== "CREATE OR REPLACE SECRET my_azure_secret (type 'azure', account_name 'myaccount', account_key 'myaccountkey');"
)


Expand Down