Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Changelog

## Unreleased
- Added canonical [PostgreSQL client parameter `sslmode`], implementing
`sslmode=require` to connect to SSL-enabled CrateDB instances without
verifying the host name. The previous `ssl=true` parameter is flagged
for deprecation, therefore `sslmode` takes precedence while both
options coexist.

[PostgreSQL client parameter `sslmode`]: https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION

## 2025/01/30 0.41.0
- Dependencies: Updated to `crate-2.0.0`, which uses `orjson` for JSON marshalling
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

linkcheck_anchors = True
linkcheck_ignore = [
r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*"
r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*",
r"https://realpython.com/",
]

rst_prolog = """
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Connect to `CrateDB Cloud`_.
# Connect using SQLAlchemy Core.
import sqlalchemy as sa
dburi = "crate://admin:<PASSWORD>@example.aks1.westeurope.azure.cratedb.net:4200?ssl=true"
dburi = "crate://admin:<PASSWORD>@example.aks1.westeurope.azure.cratedb.net:4200?sslmode=require"
engine = sa.create_engine(dburi, echo=True)
Load results into `pandas`_ DataFrame.
Expand Down
8 changes: 4 additions & 4 deletions docs/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ Here, ``<HOST_ADDR>`` is the hostname or IP address of the CrateDB node and

When authentication is needed, the credentials can be optionally supplied using
``<USERNAME>:<PASSWORD>@``. For connecting to an SSL-secured HTTP endpoint, you
can add the query parameter ``?ssl=true`` to the database URI.
can add the query parameter ``?sslmode=require`` to the database URI.

Example database URIs:

- ``crate://localhost:4200``
- ``crate://crate-1.vm.example.com:4200``
- ``crate://username:password@crate-2.vm.example.com:4200/?ssl=true``
- ``crate://username:password@crate-2.vm.example.com:4200/?sslmode=require``
- ``crate://198.51.100.1:4200``

.. TIP::
Expand Down Expand Up @@ -154,11 +154,11 @@ Once you have an CrateDB ``engine`` set up, you can create and use an SQLAlchemy
Connecting to CrateDB Cloud
...........................

Connecting to `CrateDB Cloud`_ works like this. Please note the ``?ssl=true``
Connecting to `CrateDB Cloud`_ works like this. Please note the ``?sslmode=require``
query parameter at the end of the database URI.

>>> import sqlalchemy as sa
>>> dburi = "crate://admin:<PASSWORD>@example.aks1.westeurope.azure.cratedb.net:4200?ssl=true"
>>> dburi = "crate://admin:<PASSWORD>@example.aks1.westeurope.azure.cratedb.net:4200?sslmode=require"
>>> engine = sa.create_engine(dburi, echo=True)


Expand Down
44 changes: 38 additions & 6 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
# software solely pursuant to the terms of the relevant commercial agreement.

import logging
import warnings
from datetime import date, datetime

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import functions
from sqlalchemy.util import asbool, to_list

Expand All @@ -34,6 +36,7 @@
)
from .sa_version import SA_1_4, SA_2_0, SA_VERSION
from .type import FloatVector, ObjectArray, ObjectType
from .util import SSLMode

TYPES_MAP = {
"boolean": sqltypes.Boolean,
Expand Down Expand Up @@ -226,12 +229,41 @@ def connect(self, host=None, port=None, *args, **kwargs):
if "servers" in kwargs:
server = kwargs.pop("servers")
servers = to_list(server)
if servers:
use_ssl = asbool(kwargs.pop("ssl", False))
if use_ssl:
servers = ["https://" + server for server in servers]
return self.dbapi.connect(servers=servers, **kwargs)
return self.dbapi.connect(**kwargs)

# Process legacy SSL option `ssl`.
if "ssl" in kwargs:
warnings.warn(
"The `ssl=true` option will be deprecated, "
"please use `sslmode=require` going forward.",
DeprecationWarning,
stacklevel=2,
)
use_ssl = asbool(kwargs.pop("ssl", False))

# Process new SSL option `sslmode`.
# Please consult https://www.postgresql.org/docs/18/libpq-connect.html.
if "sslmode" in kwargs:
try:
sslmode = SSLMode.parse(kwargs.pop("sslmode"))
except AttributeError as exc:
modes = ", ".join(SSLMode.modes)
raise SQLAlchemyError(
"`sslmode` parameter must be one of: {}".format(modes)
) from exc
if sslmode < SSLMode.allow:
use_ssl = False
else:
use_ssl = True
if sslmode >= SSLMode.verify_ca:
kwargs["verify_ssl_cert"] = True
else:
kwargs["verify_ssl_cert"] = False

if not servers:
servers = [self.dbapi.http.Client.default_server.replace("http://", "")]
if use_ssl:
servers = ["https://" + server for server in servers]
return self.dbapi.connect(servers=servers, **kwargs)

def do_execute(self, cursor, statement, parameters, context=None):
"""
Expand Down
27 changes: 27 additions & 0 deletions src/sqlalchemy_cratedb/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import enum

from sqlalchemy.util import classproperty


class SSLMode(enum.IntEnum):
"""
SSLMode class from asyncpg, with a little improvement.
https://github.com/MagicStack/asyncpg/blob/v0.31.0/asyncpg/connect_utils.py#L36-L48
"""

disable = 0
allow = 1
prefer = 2
require = 3
verify_ca = 4
verify_full = 5

@classmethod
def parse(cls, sslmode):
if isinstance(sslmode, cls):
return sslmode
return getattr(cls, sslmode.replace("-", "_"))

@classproperty
def modes(cls):
return [m.name.replace("_", "-") for m in cls]
88 changes: 72 additions & 16 deletions tests/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,34 @@
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

import contextlib
import warnings
from unittest import TestCase

import pytest
import sqlalchemy as sa
from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError

from sqlalchemy_cratedb import SA_1_4, SA_VERSION
from tests.util import ExtraAssertions


class SqlAlchemyConnectionTest(TestCase):
class SqlAlchemyConnectionTest(TestCase, ExtraAssertions):
def test_connection_server_uri_unknown_sa_plugin(self):
with self.assertRaises(NoSuchModuleError):
sa.create_engine("foobar://otherhost:19201")

def test_default_connection(self):
def test_connection_no_hostname_no_ssl(self):
engine = sa.create_engine("crate://")
conn = engine.raw_connection()
self.assertEqual(
"<Connection <Client ['http://127.0.0.1:4200']>>", repr(conn.driver_connection)
)
conn.close()
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["http://127.0.0.1:4200"], servers)
engine.dispose()

@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Not supported by SQLAlchemy 1.3")
def test_connection_no_hostname_with_ssl(self):
engine = sa.create_engine("crate://?sslmode=require")
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["https://127.0.0.1:4200"], servers)
engine.dispose()

def test_connection_server_uri_http(self):
Expand All @@ -48,15 +57,62 @@ def test_connection_server_uri_http(self):
conn.close()
engine.dispose()

def test_connection_server_uri_https(self):
engine = sa.create_engine("crate://otherhost:19201/?ssl=true")
conn = engine.raw_connection()
self.assertEqual(
"<Connection <Client ['https://otherhost:19201']>>", repr(conn.driver_connection)
)
conn.close()
@contextlib.contextmanager
def verify_user_warning_about_ssl_deprecation(self):
"""
The `ssl=true` option was flagged for deprecation. Verify that.
"""
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")

# Run workhorse body.
yield

# Verify details of the deprecation warning.
self.assertEqual(len(w), 1)
self.assertIsSubclass(w[-1].category, DeprecationWarning)
self.assertIn(
"The `ssl=true` option will be deprecated, "
"please use `sslmode=require` going forward.",
str(w[-1].message),
)

def test_connection_server_uri_https_ssl_enabled(self):
with self.verify_user_warning_about_ssl_deprecation():
engine = sa.create_engine("crate://otherhost:19201/?ssl=true")
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["https://otherhost:19201"], servers)
engine.dispose()

def test_connection_server_uri_https_ssl_disabled(self):
with self.verify_user_warning_about_ssl_deprecation():
engine = sa.create_engine("crate://otherhost:19201/?ssl=false")
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["http://otherhost:19201"], servers)
engine.dispose()

def test_connection_server_uri_https_sslmode_enabled(self):
engine = sa.create_engine("crate://otherhost:19201/?sslmode=require")
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["https://otherhost:19201"], servers)
engine.dispose()

def test_connection_server_uri_https_sslmode_disabled(self):
engine = sa.create_engine("crate://otherhost:19201/?sslmode=disable")
servers = engine.raw_connection().driver_connection.client._active_servers
self.assertEqual(["http://otherhost:19201"], servers)
engine.dispose()

def test_connection_server_uri_https_sslmode_invalid(self):
with pytest.raises(SQLAlchemyError) as exc_info:
engine = sa.create_engine("crate://otherhost:19201/?sslmode=foo")
engine.raw_connection()
exc_info.match(
"`sslmode` parameter must be one of: "
"disable, allow, prefer, require, verify-ca, verify-full"
)

def test_connection_server_uri_invalid_port(self):
with self.assertRaises(ValueError) as context:
sa.create_engine("crate://foo:bar")
Expand Down
Loading