Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying
[KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions,
you can use it like `FloatVector(dimensions=1536)`.
- Fixed `get_table_names()` reflection method to respect the
`schema` query argument in SQLAlchemy connection URLs.

[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ release = [
"twine<6",
]
test = [
"cratedb-toolkit[testing]",
"dask[dataframe]",
"pandas<2.3",
"pueblo>=0.0.7",
Expand Down
11 changes: 11 additions & 0 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ def connect(self, host=None, port=None, *args, **kwargs):
def _get_default_schema_name(self, connection):
return 'doc'

def _get_effective_schema_name(self, connection):
schema_name_raw = connection.engine.url.query.get("schema")
schema_name = None
if isinstance(schema_name_raw, str):
schema_name = schema_name_raw
elif isinstance(schema_name_raw, tuple):
schema_name = schema_name_raw[0]
return schema_name

def _get_server_version_info(self, connection):
return tuple(connection.connection.lowest_server_version.version)

Expand Down Expand Up @@ -258,6 +267,8 @@ def get_schema_names(self, connection, **kw):

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self._get_effective_schema_name(connection)
cursor = connection.exec_driver_sql(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
Expand Down
4 changes: 1 addition & 3 deletions tests/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
import sys
import warnings
from textwrap import dedent
from unittest import mock, skipIf, TestCase
Expand Down Expand Up @@ -289,8 +288,7 @@ def test_for_update(self):
FakeCursor = MagicMock(name='FakeCursor', spec=Cursor)


@skipIf(SA_VERSION < SA_1_4 and (3, 9) <= sys.version_info < (3, 10),
"SQLAlchemy 1.3 has problems with these test cases on Python 3.9")
@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
class CompilerTestCase(TestCase):
"""
A base class for providing mocking infrastructure to validate the DDL compiler.
Expand Down
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2021-2023, Crate.io Inc.
# Distributed under the terms of the AGPLv3 license, see LICENSE.
import pytest
from cratedb_toolkit.testing.testcontainers.cratedb import CrateDBTestAdapter

# Use different schemas for storing the subsystem database tables, and the
# test/example data, so that they do not accidentally touch the default `doc`
# schema.
TESTDRIVE_EXT_SCHEMA = "testdrive-ext"
TESTDRIVE_DATA_SCHEMA = "testdrive-data"


@pytest.fixture(scope="session")
def cratedb_service():
"""
Provide a CrateDB service instance to the test suite.
"""
db = CrateDBTestAdapter()
db.start()
yield db
db.stop()
7 changes: 6 additions & 1 deletion tests/datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
# software solely pursuant to the terms of the relevant commercial agreement.

from __future__ import absolute_import

from datetime import datetime, tzinfo, timedelta
from unittest import TestCase
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock

import sqlalchemy as sa
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session

from sqlalchemy_cratedb import SA_VERSION, SA_1_4

try:
from sqlalchemy.orm import declarative_base
except ImportError:
Expand All @@ -52,6 +56,7 @@ def dst(self, date_time):
return timedelta(seconds=-7200)


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
@patch('crate.client.connection.Cursor', FakeCursor)
class SqlAlchemyDateAndDateTimeTest(TestCase):

Expand Down
6 changes: 4 additions & 2 deletions tests/dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# software solely pursuant to the terms of the relevant commercial agreement.

from __future__ import absolute_import
from unittest import TestCase

from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock

import sqlalchemy as sa
Expand All @@ -31,7 +32,7 @@
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_cratedb import ObjectArray, ObjectType
from sqlalchemy_cratedb import ObjectArray, ObjectType, SA_VERSION, SA_1_4
from crate.client.cursor import Cursor


Expand All @@ -40,6 +41,7 @@
FakeCursor.return_value = fake_cursor


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
class SqlAlchemyDictTypeTest(TestCase):

def setUp(self):
Expand Down
7 changes: 5 additions & 2 deletions tests/insert_from_select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

from datetime import datetime
from unittest import TestCase
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock

import sqlalchemy as sa
from sqlalchemy import select, insert
from sqlalchemy.orm import Session

from sqlalchemy_cratedb import SA_VERSION, SA_1_4

try:
from sqlalchemy.orm import declarative_base
except ImportError:
Expand All @@ -40,6 +42,7 @@
FakeCursor.return_value = fake_cursor


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
class SqlAlchemyInsertFromSelectTest(TestCase):

def assertSQL(self, expected_str, actual_expr):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import sqlalchemy as sa

from tests.conftest import TESTDRIVE_DATA_SCHEMA


def test_correct_schema(cratedb_service):
"""
Tests that the correct schema is being picked up.
"""
database = cratedb_service.database

tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
inspector: sa.Inspector = sa.inspect(database.engine)
database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1")

assert TESTDRIVE_DATA_SCHEMA in inspector.get_schema_names()

table_names = inspector.get_table_names(schema=TESTDRIVE_DATA_SCHEMA)
assert table_names == ["foobar"]

view_names = inspector.get_view_names(schema=TESTDRIVE_DATA_SCHEMA)
assert view_names == []

indexes = inspector.get_indexes(tablename)
assert indexes == []
6 changes: 3 additions & 3 deletions tests/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

from datetime import datetime
from unittest import TestCase
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock

from sqlalchemy_cratedb import ObjectType
from sqlalchemy_cratedb import ObjectType, SA_VERSION, SA_1_4

import sqlalchemy as sa
from sqlalchemy.orm import Session
Expand All @@ -41,6 +40,7 @@
FakeCursor.return_value = fake_cursor


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
class SqlAlchemyUpdateTest(TestCase):

def setUp(self):
Expand Down