diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 0000000..419e243
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: pytest
+# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
\ No newline at end of file
diff --git a/test/test_table.py b/test/test_table.py
new file mode 100644
index 0000000..375bd9a
--- /dev/null
+++ b/test/test_table.py
@@ -0,0 +1,571 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
+from sqlalchemy.sql.expression import false
+
+from dataset import chunked
+
+from .conftest import TEST_DATA, TEST_CITY_1
+
+
+def test_insert(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ last_id = table.insert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_insert_ignore(table):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_ignore_all_key(table):
+ for i in range(0, 4):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_json(table):
+ last_id = table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "place": "Berlin",
+ "info": {
+ "currency": "EUR",
+ "language": "German",
+ "population": 3292365,
+ },
+ }
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_upsert(table):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_single_column(db):
+ table = db["banana_single_col"]
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+
+
+def test_upsert_all_key(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ for i in range(0, 2):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_id(db):
+ table = db["banana_with_id"]
+ data = dict(id=10, title="I am a banana!")
+ table.upsert(data, ["id"])
+ assert len(table) == 1, len(table)
+
+
+def test_update_while_iter(table):
+ for row in table:
+ row["foo"] = "bar"
+ table.update(row, ["place", "date"])
+ assert len(table) == len(TEST_DATA), len(table)
+
+
+def test_weird_column_names(table):
+ with pytest.raises(ValueError):
+ table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "foo.bar": "Berlin",
+ "qux.bar": "Huhu",
+ }
+ )
+
+
+def test_cased_column_names(db):
+ tbl = db["cased_column_names"]
+ tbl.insert({"place": "Berlin"})
+ tbl.insert({"Place": "Berlin"})
+ tbl.insert({"PLACE ": "Berlin"})
+ assert len(tbl.columns) == 2, tbl.columns
+ assert len(list(tbl.find(Place="Berlin"))) == 3
+ assert len(list(tbl.find(place="Berlin"))) == 3
+ assert len(list(tbl.find(PLACE="Berlin"))) == 3
+
+
+def test_invalid_column_names(db):
+ tbl = db["weather"]
+ with pytest.raises(ValueError):
+ tbl.insert({None: "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"": "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"-": "banana"})
+
+
+def test_delete(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ original_count = len(table)
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ # Test bad use of API
+ with pytest.raises(ArgumentError):
+ table.delete({"place": "Berlin"})
+ assert len(table) == original_count, len(table)
+
+ assert table.delete(place="Berlin") is True, "should return 1"
+ assert len(table) == len(TEST_DATA), len(table)
+ assert table.delete() is True, "should return non zero"
+ assert len(table) == 0, len(table)
+
+
+def test_repr(table):
+ assert (
+        repr(table) == "<Table(weather)>"
+    ), "the representation should be <Table(weather)>"
+
+
+def test_delete_nonexist_entry(table):
+ assert (
+ table.delete(place="Berlin") is False
+ ), "entry not exist, should fail to delete"
+
+
+def test_find_one(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ d = table.find_one(place="Berlin")
+ assert d["temperature"] == -10, d
+ d = table.find_one(place="Atlantis")
+ assert d is None, d
+
+
+def test_count(table):
+ assert len(table) == 6, len(table)
+ length = table.count(place=TEST_CITY_1)
+ assert length == 3, length
+
+
+def test_find(table):
+ ds = list(table.find(place=TEST_CITY_1))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=1, _step=2))
+ assert len(ds) == 1, ds
+ ds = list(table.find(_step=2))
+ assert len(ds) == len(TEST_DATA), ds
+ ds = list(table.find(order_by=["temperature"]))
+ assert ds[0]["temperature"] == -1, ds
+ ds = list(table.find(order_by=["-temperature"]))
+ assert ds[0]["temperature"] == 8, ds
+ ds = list(table.find(table.table.columns.temperature > 4))
+ assert len(ds) == 3, ds
+
+
+def test_find_dsl(table):
+ ds = list(table.find(place={"like": "%lw%"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={">": 5}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={">=": 5}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={"<": 0}))
+ assert len(ds) == 1, ds
+ ds = list(table.find(temperature={"<=": 0}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={"!=": -1}))
+ assert len(ds) == 5, ds
+ ds = list(table.find(temperature={"between": [5, 8]}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"=": "G€lway"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"ilike": "%LwAy"}))
+ assert len(ds) == 3, ds
+
+
+def test_offset(table):
+ ds = list(table.find(place=TEST_CITY_1, _offset=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2))
+ assert len(ds) == 1, ds
+
+
+def test_streamed_update(table):
+ ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1))
+ assert len(ds) == 3, len(ds)
+ for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1):
+ row["temperature"] = -1
+ table.update(row, ["id"])
+
+
+def test_distinct(table):
+ x = list(table.distinct("place"))
+ assert len(x) == 2, x
+ x = list(table.distinct("place", "date"))
+ assert len(x) == 6, x
+ x = list(
+ table.distinct(
+ "place",
+ "date",
+ table.table.columns.date >= datetime(2011, 1, 2, 0, 0),
+ )
+ )
+ assert len(x) == 4, x
+
+ x = list(table.distinct("temperature", place="B€rkeley"))
+ assert len(x) == 3, x
+ x = list(table.distinct("temperature", place=["B€rkeley", "G€lway"]))
+ assert len(x) == 6, x
+
+
+def test_insert_many(table):
+ data = TEST_DATA * 100
+ table.insert_many(data, chunk_size=13)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert(table):
+ data = TEST_DATA * 100
+ with chunked.ChunkedInsert(table) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert_callback(table):
+ data = TEST_DATA * 100
+ N = 0
+
+ def callback(queue):
+ nonlocal N
+ N += len(queue)
+
+ with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(data) == N
+ assert len(table) == len(data) + 6
+
+
+def test_update_many(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
+ tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
+
+
+def test_chunked_update(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many(
+ [
+ dict(temp=10, location="asdf"),
+ dict(temp=20, location="qwer"),
+ dict(temp=30, location="asdf"),
+ ]
+ )
+ db.commit()
+
+ chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"])
+ chunked_tbl.update(dict(id=1, temp=50))
+ chunked_tbl.update(dict(id=2, location="asdf"))
+ chunked_tbl.update(dict(id=3, temp=50))
+ chunked_tbl.flush()
+ db.commit()
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
+ assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
+
+
+def test_upsert_many(db):
+ # Also tests updating on records with different attributes
+ tbl = db["upsert_many_test"]
+
+ W = 100
+ tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
+ assert tbl.find_one(id=1)["age"] == 10
+
+ tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
+ assert tbl.find_one(id=2)["weight"] == W / 2
+
+
+def test_drop_operations(table):
+ assert table._table is not None, "table shouldn't be dropped yet"
+ table.drop()
+ assert table._table is None, "table should be dropped now"
+ assert list(table.all()) == [], table.all()
+ assert table.count() == 0, table.count()
+
+
+def test_table_drop(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+
+
+def test_table_drop_then_create(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+ db["weather"].insert({"foo": "bar"})
+
+
+def test_columns(table):
+ cols = table.columns
+ assert len(list(cols)) == 4, "column count mismatch"
+ assert "date" in cols and "temperature" in cols and "place" in cols
+
+
+def test_drop_column(table):
+ try:
+ table.drop_column("date")
+ assert "date" not in table.columns
+ except RuntimeError:
+ pass
+
+
+def test_iter(table):
+ c = 0
+ for row in table:
+ c += 1
+ assert c == len(table)
+
+
+def test_update(table):
+ date = datetime(2011, 1, 2)
+ res = table.update(
+ {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"]
+ )
+ assert res, "update should return True"
+ m = table.find_one(place=TEST_CITY_1, date=date)
+ assert m["temperature"] == -10, (
+ "new temp. should be -10 but is %d" % m["temperature"]
+ )
+
+
+def test_create_column(db, table):
+ flt = db.types.float
+ table.create_column("foo", flt)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type
+ assert "foo" in table.columns, table.columns
+
+
+def test_ensure_column(db, table):
+ flt = db.types.float
+ table.create_column_by_example("foo", 0.1)
+ assert "foo" in table.table.c, table.table.c
+    assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type
+ table.create_column_by_example("bar", 1)
+ assert "bar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type
+ table.create_column_by_example("pippo", "test")
+ assert "pippo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type
+ table.create_column_by_example("bigbar", 11111111111)
+ assert "bigbar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[
+ "bigbar"
+ ].type
+ table.create_column_by_example("littlebar", -11111111111)
+ assert "littlebar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[
+ "littlebar"
+ ].type
+
+
+def test_key_order(db, table):
+ res = db.query("SELECT temperature, place FROM weather LIMIT 1")
+ keys = list(res.next().keys())
+ assert keys[0] == "temperature"
+ assert keys[1] == "place"
+
+
+def test_empty_query(table):
+ empty = list(table.find(place="not in data"))
+ assert len(empty) == 0, empty
+
+def test_find_unknown_operator(table):
+ """Test find with an unknown operator returns no results."""
+ # Insert a distinct row below to guarantee a match is possible if operator was recognized
+ table.insert({"date": datetime(2020, 1, 1), "temperature": 20, "place": "Test"})
+ # Using an unknown operator "foo" should yield a clause that always evaluates to false.
+ results = list(table.find(temperature={"foo": 20}))
+ assert len(results) == 0, "Unknown operator should result in no rows matching"
+
+def test_create_index_existing(table):
+ """Test that creating an index twice does not fail and is handled gracefully."""
+ # Create an index on "place"
+ table.create_index("place")
+ # Attempt again to create the same index
+ table.create_index("place")
+ # Verify that has_index returns True.
+ assert table.has_index("place")
+
+def test_drop_column_sqlite(db, table):
+ """Test that dropping a column on SQLite raises a RuntimeError."""
+ original_is_sqlite = db.is_sqlite
+ db.is_sqlite = True
+ try:
+ with pytest.raises(RuntimeError):
+ table.drop_column("temperature")
+ finally:
+ db.is_sqlite = original_is_sqlite
+
+def test_sync_columns_without_ensure(table):
+ """Test that _sync_columns does not add new keys when ensure is False."""
+ # Capture the current columns.
+ original_columns = set(table.columns)
+ # Create a row with a non-existing column.
+ new_row = {"non_existing_col": "value", "place": "Berlin", "date": datetime(2011, 1, 2)}
+ synced = table._sync_columns(new_row, ensure=False)
+ # "non_existing_col" should not be part of the synced row.
+ assert "non_existing_col" not in synced
+ # Existing columns should remain.
+ assert "place" in synced and "date" in synced
+
+def test_keys_to_args_missing_keys(table):
+ """Test _keys_to_args when the row is missing some keys."""
+ row = {"date": datetime(2011, 1, 2), "temperature": -10}
+ keys = ["date", "place"]
+ args, remaining = table._keys_to_args(row, keys)
+ # "date" exists, but "place" does not.
+ assert args["date"] is not None
+ assert args["place"] is None
+ # The remaining row should contain only keys not in "keys"
+ assert "temperature" in remaining and len(remaining) == 1
+
+def test_generate_clause_edge(table):
+ """Test _generate_clause with an unknown operator returns a false clause."""
+ clause = table._generate_clause("place", "unknown", "value")
+ # Converting to string to compare with false() clause.
+ assert str(clause) == str(false())
+
+def test_update_with_missing_filter_keys(table):
+ """Test update returns zero count when no rows match the filter keys."""
+ # Attempt to update a non-existent row (using "id" as filter)
+ result = table.update({"date": datetime(2022, 2, 2), "temperature": 5, "place": "NonExistent"}, ["id"])
+ # Since no row should match, the update should return a rowcount of 0 (or possibly None).
+ assert result == 0 or result is None
+
+# End of added tests
+def test_threading_warning(table, monkeypatch):
+ """Test that a RuntimeWarning is triggered when changing the schema during a transaction in a multi-threaded context."""
+ # Simulate being in a transaction and having multiple active threads.
+ monkeypatch.setattr(type(table.db), "in_transaction", property(lambda self: True))
+ monkeypatch.setattr("threading.active_count", lambda: 2)
+ with pytest.warns(RuntimeWarning, match="Changing the database schema inside a transaction"):
+ table._threading_warn()
+
+def test_sync_columns_with_ensure(table):
+ """Test that _sync_columns adds non-existing keys when ensure=True and updates the table schema."""
+ old_columns = set(table.columns)
+ new_row = {"date": datetime(2011, 1, 2), "place": "Berlin", "extra": "value"}
+ result = table._sync_columns(new_row, ensure=True)
+ # Check that the new 'extra' column is added in returned row and table.columns.
+ assert "extra" in result
+ assert "extra" in table.columns
+
+def test_update_no_change(table):
+ """Test update returns the number of matching rows if no new value is provided after filter keys are removed."""
+ # Insert a sample row.
+ table.insert({"date": datetime(2011, 1, 2), "temperature": 20, "place": "NoChange"})
+ count_before = table.count(place="NoChange")
+ # Provide an update where the row dictionary only contains the filter keys.
+ res = table.update({"date": datetime(2011, 1, 2), "place": "NoChange"}, ["date", "place"])
+ # Since there are no new column values to update, update() should return the count of matching rows.
+ assert res == count_before
+
+def test_upsert_many_empty(db):
+ """Test that calling upsert_many with an empty list performs no operations and leaves the table empty."""
+ tbl = db["upsert_many_empty_test"]
+ tbl.upsert_many([], "id")
+ assert tbl.count() == 0
+def test_generate_clause_startswith(table):
+ """Test _generate_clause with 'startswith' operator returns proper LIKE clause."""
+ clause = table._generate_clause("place", "startswith", "Ber")
+ compiled = clause.compile(compile_kwargs={"literal_binds": True})
+ clause_str = str(compiled)
+ assert "LIKE" in clause_str and "Ber%" in clause_str
+
+def test_generate_clause_endswith(table):
+ """Test _generate_clause with 'endswith' operator returns proper LIKE clause."""
+ clause = table._generate_clause("place", "endswith", "lin")
+ compiled = clause.compile(compile_kwargs={"literal_binds": True})
+ clause_str = str(compiled)
+ assert "LIKE" in clause_str and "%lin" in clause_str
+
+def test_generate_clause_in(table):
+ """Test _generate_clause with 'in' operator returns proper IN clause."""
+ clause = table._generate_clause("temperature", "in", [1, 2, 3])
+ compiled = clause.compile(compile_kwargs={"literal_binds": True})
+ clause_str = str(compiled)
+ assert "IN" in clause_str
+
+def test_generate_clause_between(table):
+ """Test _generate_clause with 'between' operator returns proper BETWEEN clause."""
+ clause = table._generate_clause("temperature", "between", (5, 8))
+ compiled = clause.compile(compile_kwargs={"literal_binds": True})
+ clause_str = str(compiled)
+ assert "BETWEEN" in clause_str
+
+def test_column_keys_after_insert(table):
+ """Test that inserting a row with ensure=True adds a new column to the table schema and _column_keys."""
+ table.insert({"date": datetime(2020, 1, 1), "temperature": 0, "place": "TestCase", "new_column": "new_value"}, ensure=True)
+ table.insert({"date": datetime(2020, 1, 1), "temperature": 0, "place": "TestCase", "new_column": "new_value"}, ensure=True)
+ # Force reloading of columns by accessing the _column_keys property.
+ new_keys = set(table._column_keys.values())
+ assert "new_column" in new_keys
+ # Also check that the inserted row contains the new column.
+ row = table.find_one(new_column="new_value")
+ assert row is not None and row.get("new_column") == "new_value"
+
+def test_multiple_index_creation(table):
+ """Test creation of a multi-column index and verify that has_index returns True."""
+ # Create an index on 'date' and 'place'
+ table.create_index(["date", "place"], name="test_multi_idx")
+ assert table.has_index(["date", "place"])
+
+def test_update_no_matching_row(table):
+ """Test that calling update with filter keys that yield no matching row returns zero (or None)."""
+ res = table.update({"date": "2099-01-01", "temperature": 100, "place": "Nowhere"}, ["date", "place"])
+ assert res == 0 or res is None
\ No newline at end of file
diff --git a/tests/test_setup.py b/tests/test_setup.py
new file mode 100644
index 0000000..99a020a
--- /dev/null
+++ b/tests/test_setup.py
@@ -0,0 +1,139 @@
+import io
+import builtins
+import pytest
+import importlib
+import setuptools
+
+# Global dictionary to capture the kwargs passed to setuptools.setup
+captured_setup_kwargs = {}
+
+def dummy_setup(*args, **kwargs):
+ """Dummy setup function that captures the setup kwargs."""
+ captured_setup_kwargs.update(kwargs)
+
+@pytest.fixture(autouse=True, scope="function")
+def patch_setup(monkeypatch):
+ """Patch setuptools.setup and builtins.open before importing setup.py."""
+ # Patch setuptools.setup with our dummy function
+ monkeypatch.setattr(setuptools, "setup", dummy_setup)
+
+ # Patch open so that opening "README.md" returns a dummy stream with controlled text
+ def dummy_open(filename, *args, **kwargs):
+ if filename == "README.md":
+ return io.StringIO("Dummy long description")
+        return io.open(filename, *args, **kwargs)  # io.open is the real builtin; plain open() is monkeypatched to dummy_open and would recurse
+ monkeypatch.setattr(builtins, "open", dummy_open)
+
+    # Import (or reload) the setup module to trigger its top-level code
+ import setup
+ importlib.reload(setup)
+
+def test_setup_configuration():
+ """Test that setup.py properly calls setuptools.setup with the expected configuration."""
+ # Verify the package name and version
+ assert captured_setup_kwargs.get("name") == "dataset", "Package name should be 'dataset'"
+ assert captured_setup_kwargs.get("version") == "1.6.0", "Version should be '1.6.0'"
+
+ # Verify that the long_description is correctly populated from our dummy README.md
+ assert captured_setup_kwargs.get("long_description") == "Dummy long description", \
+ "long_description should match dummy content from README.md"
+
+ # Verify that the install_requires includes the expected packages
+ install_requires = captured_setup_kwargs.get("install_requires", [])
+ assert "sqlalchemy >= 2.0.15, < 3.0.0" in install_requires, \
+ "install_requires should include sqlalchemy version range"
+ assert "alembic >= 1.11.1" in install_requires, \
+ "install_requires should include alembic"
+ assert "banal >= 1.0.1" in install_requires, \
+ "install_requires should include banal"
+
+ # Verify that extras_require and tests_require are configured
+ extras_require = captured_setup_kwargs.get("extras_require", {})
+ assert "dev" in extras_require, "extras_require should include the 'dev' option"
+
+ tests_require = captured_setup_kwargs.get("tests_require", [])
+ assert "pytest" in tests_require, "tests_require should include 'pytest'"
+def test_setup_additional_configuration():
+ """Test additional configuration parameters of setup.py."""
+ # Test the description field
+ assert captured_setup_kwargs.get("description") == "Toolkit for Python-based database access.", "Description mismatch"
+
+ # Test long_description_content_type
+ assert captured_setup_kwargs.get("long_description_content_type") == "text/markdown", "long_description_content_type mismatch"
+
+ # Test classifiers: ensure certain known classifiers are present and total count is 9
+ classifiers = captured_setup_kwargs.get("classifiers", [])
+ expected_classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ ]
+ for classifier in expected_classifiers:
+ assert classifier in classifiers, f"Classifier {classifier} not found"
+ assert len(classifiers) == 9, "There should be 9 classifiers"
+
+ # Test keywords, author, author_email, url, and license
+ assert captured_setup_kwargs.get("keywords") == "sql sqlalchemy etl loading utility", "Keywords mismatch"
+ assert captured_setup_kwargs.get("author") == "Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer", "Author mismatch"
+ assert captured_setup_kwargs.get("author_email") == "friedrich.lindenberg@gmail.com", "Author email mismatch"
+ assert captured_setup_kwargs.get("url") == "http://github.com/pudo/dataset", "URL mismatch"
+ assert captured_setup_kwargs.get("license") == "MIT", "License mismatch"
+
+ # Test packages: should return a list and not include excluded packages
+ packages = captured_setup_kwargs.get("packages")
+ assert isinstance(packages, list), "packages should be a list"
+ for pkg in packages:
+ assert pkg not in ("ez_setup", "examples", "test"), "Excluded package found"
+
+ # Test namespace_packages, include_package_data, zip_safe, test_suite, and entry_points
+ assert captured_setup_kwargs.get("namespace_packages") == [], "namespace_packages should be empty list"
+ assert captured_setup_kwargs.get("include_package_data") is False, "include_package_data should be False"
+ assert captured_setup_kwargs.get("zip_safe") is False, "zip_safe should be False"
+ assert captured_setup_kwargs.get("test_suite") == "test", "test_suite should be 'test'"
+ assert captured_setup_kwargs.get("entry_points") == {}, "entry_points should be {}"
+def test_extras_require_dev_content():
+ """Test that the 'dev' extras in extras_require contain exactly the expected packages."""
+ extras_require = captured_setup_kwargs.get("extras_require", {})
+ dev_deps = extras_require.get("dev", [])
+ expected_deps = [
+ "pip",
+ "pytest",
+ "wheel",
+ "flake8",
+ "coverage",
+ "psycopg2-binary",
+ "PyMySQL",
+ "cryptography",
+ ]
+ for dep in expected_deps:
+ assert dep in dev_deps, f"dev extras_require missing {dep}"
+ # Check that there are no additional dependencies in dev extras
+ assert len(dev_deps) == len(expected_deps), "dev extras_require contains unexpected dependencies"
+
+def test_install_requires_is_list_of_strings():
+ """Test that install_requires is a list and that each element is a string."""
+ install_requires = captured_setup_kwargs.get("install_requires", [])
+ assert isinstance(install_requires, list), "install_requires should be a list"
+ for requirement in install_requires:
+ assert isinstance(requirement, str), "Each install requirement should be a string"
+def test_readme_not_found(monkeypatch):
+ """Test that missing README.md raises FileNotFoundError during setup module import."""
+ # Override open to simulate a missing README.md file
+ def missing_open(filename, *args, **kwargs):
+ if filename == "README.md":
+ raise FileNotFoundError("No README.md found")
+        return io.open(filename, *args, **kwargs)  # io.open bypasses the monkeypatched builtins.open, avoiding infinite recursion
+ monkeypatch.setattr(builtins, "open", missing_open)
+ import sys
+ if "setup" in sys.modules:
+ del sys.modules["setup"]
+ with pytest.raises(FileNotFoundError):
+ import setup
+
+def test_multiple_reloads(monkeypatch):
+ """Test that reloading setup.py multiple times produces consistent configuration."""
+ import setup
+ captured_first = captured_setup_kwargs.copy()
+ importlib.reload(setup)
+ captured_second = captured_setup_kwargs.copy()
+ assert captured_first == captured_second, "Configuration should be consistent across reloads"
\ No newline at end of file
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 0000000..a369b2a
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,328 @@
+import pytest
+from collections import OrderedDict, namedtuple
+from urllib.parse import urlparse, parse_qs
+from dataset.util import (convert_row, iter_result_proxy, make_sqlite_url, ResultIter,
+ normalize_column_name, normalize_column_key, normalize_table_name,
+ safe_url, index_name, pad_chunk_columns, row_type, QUERY_STEP)
+from sqlalchemy.exc import ResourceClosedError
+
+# A fake ResultProxy object to simulate fetchall, fetchmany, keys and close behavior.
class FakeResultProxy:
    """In-memory stand-in for a SQLAlchemy result proxy.

    Serves rows from a list and mimics the fetchall/fetchmany, keys and
    close behavior that dataset's result iteration relies on.
    """

    def __init__(self, data, step=None):
        self.data = data
        self.index = 0
        self.step = step

    def fetchall(self):
        # Hand back everything not yet consumed and mark all rows as read.
        remaining = self.data[self.index:]
        self.index = len(self.data)
        return remaining

    def fetchmany(self, size):
        if self.index >= len(self.data):
            return []
        start = self.index
        self.index = start + size
        return self.data[start:start + size]

    def keys(self):
        # Column names come from the first row's namedtuple fields, if any.
        if not self.data:
            return []
        first = self.data[0]
        return list(first._fields) if hasattr(first, "_fields") else []

    def close(self):
        # Nothing to release for the in-memory fake.
        pass
+
def test_convert_row_valid():
    """convert_row should map a namedtuple row into the requested row type."""
    Row = namedtuple("Row", ["a", "b"])
    converted = convert_row(OrderedDict, Row(1, 2))
    assert isinstance(converted, OrderedDict)
    assert (converted["a"], converted["b"]) == (1, 2)
+
def test_convert_row_none():
    """A None row should pass through convert_row unchanged."""
    result = convert_row(OrderedDict, None)
    assert result is None
+
def test_iter_result_proxy_fetchall():
    """Without a step, iter_result_proxy drains the proxy via fetchall."""
    Row = namedtuple("Row", ["a"])
    proxy = FakeResultProxy([Row(1), Row(2)])
    rows = [row for row in iter_result_proxy(proxy)]
    assert [row.a for row in rows] == [1, 2]
+
def test_iter_result_proxy_fetchmany():
    """With a step, iter_result_proxy pages through rows via fetchmany."""
    Row = namedtuple("Row", ["a"])
    proxy = FakeResultProxy([Row(10), Row(20), Row(30)])
    rows = list(iter_result_proxy(proxy, step=2))
    assert len(rows) == 3
    assert (rows[0].a, rows[2].a) == (10, 30)
+
def test_make_sqlite_url_no_params():
    """Without options make_sqlite_url yields a plain sqlite URL."""
    db_path = "test.db"
    assert make_sqlite_url(db_path) == "sqlite:///" + db_path
+
def test_make_sqlite_url_with_params():
    """Every optional parameter must be encoded into the URI query string."""
    path = "test.db"
    url = make_sqlite_url(path, cache="shared", timeout=5.0, mode="ro",
                          check_same_thread=False, immutable=True, nolock=True)
    # The file: URI form is used as soon as any option is present.
    assert url.startswith("sqlite:///file:" + path + "?")
    query = parse_qs(urlparse(url).query)
    expected = {
        "cache": ["shared"],
        "timeout": ["5.0"],
        "mode": ["ro"],
        "nolock": ["1"],
        "immutable": ["1"],
        "check_same_thread": ["false"],
        "uri": ["true"],
    }
    for key, value in expected.items():
        assert query.get(key) == value
+
def test_result_iter_normal():
    """ResultIter should yield row_type mappings and close without error."""
    Row = namedtuple("Row", ["a", "b"])
    proxy = FakeResultProxy([Row(1, 2), Row(3, 4)])
    iterator = ResultIter(proxy, row_type=OrderedDict, step=1)
    rows = [row for row in iterator]
    assert len(rows) == 2
    assert rows[0]["a"] == 1
    iterator.close()  # closing after exhaustion must not raise
+
def test_result_iter_resource_closed():
    """A cursor whose keys() raises ResourceClosedError yields no rows."""
    class ClosedCursor:
        def keys(self):
            raise ResourceClosedError
        def close(self):
            pass
    assert list(ResultIter(ClosedCursor())) == []
+
def test_normalize_column_name_valid():
    """Surrounding whitespace is stripped and the result fits 63 chars."""
    normalized = normalize_column_name(" valid_name ")
    assert " " not in normalized
    assert len(normalized) <= 63
+
def test_normalize_column_name_invalid():
    """Names containing '.' or '-' are rejected with ValueError."""
    for bad_name in ("inva.lid", "inva-lid"):
        with pytest.raises(ValueError):
            normalize_column_name(bad_name)
+
def test_normalize_column_name_non_string():
    """A non-string column name raises ValueError."""
    with pytest.raises(ValueError):
        normalize_column_name(123)
+
def test_normalize_column_name_length():
    """Over-long names are trimmed to stay below 64 UTF-8 bytes."""
    normalized = normalize_column_name("a" * 100)
    assert len(normalized.encode("utf-8")) < 64
+
def test_normalize_column_key():
    """Keys are stripped and upper-cased; invalid input maps to None."""
    key = normalize_column_key(" column ")
    assert key == "COLUMN"
    assert normalize_column_key(None) is None
+
def test_normalize_table_name_valid():
    """Leading/trailing whitespace is stripped from a valid table name."""
    raw = " my_table "
    assert normalize_table_name(raw) == raw.strip()[:63]
+
def test_normalize_table_name_invalid():
    """Empty or non-string table names raise ValueError."""
    for bad_name in ("", 123):
        with pytest.raises(ValueError):
            normalize_table_name(bad_name)
+
def test_safe_url_no_password():
    """A URL without credentials is returned untouched."""
    plain = "sqlite:///test.db"
    assert safe_url(plain) == plain
+
def test_safe_url_with_password():
    """The password is replaced by a masked placeholder."""
    masked = safe_url("postgresql://user:secret@localhost/db")
    assert "secret" not in masked
    assert ":*****@" in masked
+
def test_index_name():
    """Generated index names follow the ix_<table>_<16-char-hash> pattern."""
    generated = index_name("mytable", ["col1", "col2"])
    assert generated.startswith("ix_mytable_")
    # The trailing segment is a 16-character digest.
    assert len(generated.rsplit("_", 1)[-1]) == 16
+
def test_pad_chunk_columns():
    """After padding, every record carries every requested column."""
    chunk = [{"a": 1}, {"b": 2}, {"a": 3, "b": 4}]
    columns = ["a", "b", "c"]
    padded = pad_chunk_columns(chunk, columns)
    assert all(col in record for record in padded for col in columns)
+
def test_pad_chunk_columns_empty():
    """An empty chunk pads to an empty list."""
    assert pad_chunk_columns([], ["a", "b"]) == []
def test_iter_result_proxy_empty():
    """An empty proxy produces no rows at all."""
    assert list(iter_result_proxy(FakeResultProxy([]))) == []
+
def test_index_name_empty_columns():
    """index_name stays well-formed even for an empty column list."""
    from hashlib import sha1
    table = "test_table"
    # With no columns the hashed key is the digest of the empty string.
    digest = sha1(b"").hexdigest()[:16]
    assert index_name(table, []) == "ix_%s_%s" % (table, digest)
+
def test_convert_row_empty_fields():
    """A field-less namedtuple converts to an empty mapping."""
    EmptyRow = namedtuple("EmptyRow", [])
    converted = convert_row(OrderedDict, EmptyRow())
    assert isinstance(converted, OrderedDict)
    assert not converted
+
def test_normalize_column_name_empty_after_strip():
    """A whitespace-only column name is rejected with ValueError."""
    with pytest.raises(ValueError):
        normalize_column_name(" ")
+
def test_normalize_table_name_whitespace():
    """A whitespace-only table name is rejected with ValueError."""
    with pytest.raises(ValueError):
        normalize_table_name(" ")
def test_make_sqlite_url_invalid_cache():
    """An unrecognized cache value trips the internal assertion."""
    with pytest.raises(AssertionError):
        make_sqlite_url("test.db", cache="invalid")
+
def test_make_sqlite_url_invalid_mode():
    """An unrecognized mode value trips the internal assertion."""
    with pytest.raises(AssertionError):
        make_sqlite_url("test.db", mode="invalid")
+
def test_result_iter_iter_returns_self():
    """ResultIter.__iter__ must hand back the iterator object itself."""
    Row = namedtuple("Row", ["a"])
    iterator = ResultIter(FakeResultProxy([Row(1)]))
    assert iter(iterator) is iterator
+
def test_pad_chunk_columns_non_dict():
    """Non-dict records cause an AttributeError during padding."""
    # Tuples have no dict interface, so padding must blow up.
    with pytest.raises(AttributeError):
        pad_chunk_columns([("a", 1), ("b", 2)], ["a", "b"])
+
def test_normalize_column_key_non_string():
    """A non-string key normalizes to None."""
    assert normalize_column_key(123) is None
+
def test_safe_url_empty_password():
    """Even an empty password is replaced by the mask in the safe URL."""
    masked = safe_url("postgresql://user:@localhost/db")
    assert ":*****@" in masked
+
def test_result_iter_close_called():
    """ResultIter.close must forward to the underlying cursor's close."""
    class RecordingCursor:
        def __init__(self):
            self.closed = False
        def close(self):
            self.closed = True
        def keys(self):
            return ["a"]
    cursor = RecordingCursor()
    ResultIter(cursor).close()
    assert cursor.closed is True
def test_make_sqlite_url_check_same_thread_true():
    """check_same_thread=True (the default) adds no query parameters."""
    db_path = "test.db"
    # With only default-equivalent options, the basic URL form is used.
    assert make_sqlite_url(db_path, check_same_thread=True) == "sqlite:///" + db_path
+
def test_convert_row_invalid_row():
    """Objects lacking the namedtuple interface cannot be converted."""
    class NotARow:
        a = 1
    with pytest.raises(AttributeError):
        convert_row(OrderedDict, NotARow())
+
def test_safe_url_special_characters():
    """Passwords containing special characters are fully masked."""
    masked = safe_url("postgresql://user:my$ecret*pass@localhost/db")
    assert "my$ecret*pass" not in masked
    assert ":*****@" in masked
+
def test_pad_chunk_columns_overrides_none():
    """Padding adds only missing keys; an existing None value survives."""
    padded = pad_chunk_columns([{"a": None}, {"b": 2}], ["a", "b", "c"])
    first = padded[0]
    # 'a' was already present (as None) and must not be overwritten;
    # 'b' and 'c' were missing and must be filled in.
    assert first["a"] is None
    assert "b" in first
    assert "c" in first
+
def test_result_iter_multiple_iterations():
    """A ResultIter is single-use: once drained, next() raises StopIteration.

    The original test bound the first pass to an unused local and never
    checked it; assert on the drained rows so the first iteration is
    actually verified before probing exhaustion.
    """
    Row = namedtuple("Row", ["a"])
    proxy = FakeResultProxy([Row(1), Row(2)])
    iterator = ResultIter(proxy, row_type=OrderedDict, step=1)
    drained = list(iterator)
    assert len(drained) == 2  # first pass really produced both rows
    # The iterator is now exhausted; another next() must raise.
    with pytest.raises(StopIteration):
        next(iterator)
+
def test_normalize_column_key_empty_string():
    """A whitespace-only key normalizes to the empty string after stripping."""
    assert normalize_column_key(" ") == ""
\ No newline at end of file