diff --git a/CHANGES.md b/CHANGES.md index 2fba45e..53a92a1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,13 @@ # Changelog ## Unreleased +- Added canonical [PostgreSQL client parameter `sslmode`], implementing + `sslmode=require` to connect to SSL-enabled CrateDB instances without + verifying the host name. The previous `ssl=true` parameter is flagged + for deprecation, therefore `sslmode` takes precedence while both + options coexist. + +[PostgreSQL client parameter `sslmode`]: https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION ## 2025/01/30 0.41.0 - Dependencies: Updated to `crate-2.0.0`, which uses `orjson` for JSON marshalling diff --git a/docs/conf.py b/docs/conf.py index 800339a..b5df476 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,7 +30,8 @@ linkcheck_anchors = True linkcheck_ignore = [ - r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*" + r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*", + r"https://realpython.com/", ] rst_prolog = """ diff --git a/docs/index.rst b/docs/index.rst index b95b3ff..a1efbe8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -104,7 +104,7 @@ Connect to `CrateDB Cloud`_. # Connect using SQLAlchemy Core. import sqlalchemy as sa - dburi = "crate://admin:@example.aks1.westeurope.azure.cratedb.net:4200?ssl=true" + dburi = "crate://admin:@example.aks1.westeurope.azure.cratedb.net:4200?sslmode=require" engine = sa.create_engine(dburi, echo=True) Load results into `pandas`_ DataFrame. diff --git a/docs/overview.rst b/docs/overview.rst index e0590d9..faae628 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -60,13 +60,13 @@ Here, ```` is the hostname or IP address of the CrateDB node and When authentication is needed, the credentials can be optionally supplied using ``:@``. For connecting to an SSL-secured HTTP endpoint, you -can add the query parameter ``?ssl=true`` to the database URI. +can add the query parameter ``?sslmode=require`` to the database URI. Example database URIs: - ``crate://localhost:4200`` - ``crate://crate-1.vm.example.com:4200`` -- ``crate://username:password@crate-2.vm.example.com:4200/?ssl=true`` +- ``crate://username:password@crate-2.vm.example.com:4200/?sslmode=require`` - ``crate://198.51.100.1:4200`` .. TIP:: @@ -154,11 +154,11 @@ Once you have an CrateDB ``engine`` set up, you can create and use an SQLAlchemy Connecting to CrateDB Cloud ........................... -Connecting to `CrateDB Cloud`_ works like this. Please note the ``?ssl=true`` +Connecting to `CrateDB Cloud`_ works like this. Please note the ``?sslmode=require`` query parameter at the end of the database URI. >>> import sqlalchemy as sa - >>> dburi = "crate://admin:@example.aks1.westeurope.azure.cratedb.net:4200?ssl=true" + >>> dburi = "crate://admin:@example.aks1.westeurope.azure.cratedb.net:4200?sslmode=require" >>> engine = sa.create_engine(dburi, echo=True) diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 90102a7..a5932c4 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -20,10 +20,12 @@ # software solely pursuant to the terms of the relevant commercial agreement. import logging +import warnings from datetime import date, datetime from sqlalchemy import types as sqltypes from sqlalchemy.engine import default, reflection +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql import functions from sqlalchemy.util import asbool, to_list @@ -34,6 +36,7 @@ ) from .sa_version import SA_1_4, SA_2_0, SA_VERSION from .type import FloatVector, ObjectArray, ObjectType +from .util import SSLMode TYPES_MAP = { "boolean": sqltypes.Boolean, @@ -226,12 +229,41 @@ def connect(self, host=None, port=None, *args, **kwargs): if "servers" in kwargs: server = kwargs.pop("servers") servers = to_list(server) - if servers: - use_ssl = asbool(kwargs.pop("ssl", False)) - if use_ssl: - servers = ["https://" + server for server in servers] - return self.dbapi.connect(servers=servers, **kwargs) - return self.dbapi.connect(**kwargs) + + # Process legacy SSL option `ssl`. + if "ssl" in kwargs: + warnings.warn( + "The `ssl=true` option will be deprecated, " + "please use `sslmode=require` going forward.", + DeprecationWarning, + stacklevel=2, + ) + use_ssl = asbool(kwargs.pop("ssl", False)) + + # Process new SSL option `sslmode`. + # Please consult https://www.postgresql.org/docs/18/libpq-connect.html. + if "sslmode" in kwargs: + try: + sslmode = SSLMode.parse(kwargs.pop("sslmode")) + except AttributeError as exc: + modes = ", ".join(SSLMode.modes) + raise SQLAlchemyError( + "`sslmode` parameter must be one of: {}".format(modes) + ) from exc + if sslmode < SSLMode.allow: + use_ssl = False + else: + use_ssl = True + if sslmode >= SSLMode.verify_ca: + kwargs["verify_ssl_cert"] = True + else: + kwargs["verify_ssl_cert"] = False + + if not servers: + servers = [self.dbapi.http.Client.default_server.replace("http://", "")] + if use_ssl: + servers = ["https://" + server for server in servers] + return self.dbapi.connect(servers=servers, **kwargs) def do_execute(self, cursor, statement, parameters, context=None): """ diff --git a/src/sqlalchemy_cratedb/util.py b/src/sqlalchemy_cratedb/util.py new file mode 100644 index 0000000..1c34b33 --- /dev/null +++ b/src/sqlalchemy_cratedb/util.py @@ -0,0 +1,27 @@ +import enum + +from sqlalchemy.util import classproperty + + +class SSLMode(enum.IntEnum): + """ + SSLMode class from asyncpg, with a little improvement. + https://github.com/MagicStack/asyncpg/blob/v0.31.0/asyncpg/connect_utils.py#L36-L48 + """ + + disable = 0 + allow = 1 + prefer = 2 + require = 3 + verify_ca = 4 + verify_full = 5 + + @classmethod + def parse(cls, sslmode): + if isinstance(sslmode, cls): + return sslmode + return getattr(cls, sslmode.replace("-", "_")) + + @classproperty + def modes(cls): + return [m.name.replace("_", "-") for m in cls] diff --git a/tests/connection_test.py b/tests/connection_test.py index 00adb25..b7567a1 100644 --- a/tests/connection_test.py +++ b/tests/connection_test.py @@ -18,25 +18,34 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. - +import contextlib +import warnings from unittest import TestCase +import pytest import sqlalchemy as sa -from sqlalchemy.exc import NoSuchModuleError +from sqlalchemy.exc import NoSuchModuleError, SQLAlchemyError + +from sqlalchemy_cratedb import SA_1_4, SA_VERSION +from tests.util import ExtraAssertions -class SqlAlchemyConnectionTest(TestCase): +class SqlAlchemyConnectionTest(TestCase, ExtraAssertions): def test_connection_server_uri_unknown_sa_plugin(self): with self.assertRaises(NoSuchModuleError): sa.create_engine("foobar://otherhost:19201") - def test_default_connection(self): + def test_connection_no_hostname_no_ssl(self): engine = sa.create_engine("crate://") - conn = engine.raw_connection() - self.assertEqual( - ">", repr(conn.driver_connection) - ) - conn.close() + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["http://127.0.0.1:4200"], servers) + engine.dispose() + + @pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Not supported by SQLAlchemy 1.3") + def test_connection_no_hostname_with_ssl(self): + engine = sa.create_engine("crate://?sslmode=require") + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["https://127.0.0.1:4200"], servers) engine.dispose() def test_connection_server_uri_http(self): @@ -48,15 +57,62 @@ def test_connection_server_uri_http(self): conn.close() engine.dispose() - def test_connection_server_uri_https(self): - engine = sa.create_engine("crate://otherhost:19201/?ssl=true") - conn = engine.raw_connection() - self.assertEqual( - ">", repr(conn.driver_connection) - ) - conn.close() + @contextlib.contextmanager + def verify_user_warning_about_ssl_deprecation(self): + """ + The `ssl=true` option was flagged for deprecation. Verify that. + """ + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + + # Run workhorse body. + yield + + # Verify details of the deprecation warning. + self.assertEqual(len(w), 1) + self.assertIsSubclass(w[-1].category, DeprecationWarning) + self.assertIn( + "The `ssl=true` option will be deprecated, " + "please use `sslmode=require` going forward.", + str(w[-1].message), + ) + + def test_connection_server_uri_https_ssl_enabled(self): + with self.verify_user_warning_about_ssl_deprecation(): + engine = sa.create_engine("crate://otherhost:19201/?ssl=true") + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["https://otherhost:19201"], servers) + engine.dispose() + + def test_connection_server_uri_https_ssl_disabled(self): + with self.verify_user_warning_about_ssl_deprecation(): + engine = sa.create_engine("crate://otherhost:19201/?ssl=false") + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["http://otherhost:19201"], servers) engine.dispose() + def test_connection_server_uri_https_sslmode_enabled(self): + engine = sa.create_engine("crate://otherhost:19201/?sslmode=require") + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["https://otherhost:19201"], servers) + engine.dispose() + + def test_connection_server_uri_https_sslmode_disabled(self): + engine = sa.create_engine("crate://otherhost:19201/?sslmode=disable") + servers = engine.raw_connection().driver_connection.client._active_servers + self.assertEqual(["http://otherhost:19201"], servers) + engine.dispose() + + def test_connection_server_uri_https_sslmode_invalid(self): + with pytest.raises(SQLAlchemyError) as exc_info: + engine = sa.create_engine("crate://otherhost:19201/?sslmode=foo") + engine.raw_connection() + exc_info.match( + "`sslmode` parameter must be one of: " + "disable, allow, prefer, require, verify-ca, verify-full" + ) + def test_connection_server_uri_invalid_port(self): with self.assertRaises(ValueError) as context: sa.create_engine("crate://foo:bar")