diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 0000000..419e243
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: pytest
+# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
\ No newline at end of file
diff --git a/test/test_table.py b/test/test_table.py
new file mode 100644
index 0000000..c2af325
--- /dev/null
+++ b/test/test_table.py
@@ -0,0 +1,612 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
+import threading
+from dataset import chunked
+
+from .conftest import TEST_DATA, TEST_CITY_1
+
+
+def test_insert(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ last_id = table.insert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_insert_ignore(table):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_ignore_all_key(table):
+ for i in range(0, 4):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_json(table):
+ last_id = table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "place": "Berlin",
+ "info": {
+ "currency": "EUR",
+ "language": "German",
+ "population": 3292365,
+ },
+ }
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_upsert(table):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_single_column(db):
+ table = db["banana_single_col"]
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+
+
+def test_upsert_all_key(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ for i in range(0, 2):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_id(db):
+ table = db["banana_with_id"]
+ data = dict(id=10, title="I am a banana!")
+ table.upsert(data, ["id"])
+ assert len(table) == 1, len(table)
+
+
+def test_update_while_iter(table):
+ for row in table:
+ row["foo"] = "bar"
+ table.update(row, ["place", "date"])
+ assert len(table) == len(TEST_DATA), len(table)
+
+
+def test_weird_column_names(table):
+ with pytest.raises(ValueError):
+ table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "foo.bar": "Berlin",
+ "qux.bar": "Huhu",
+ }
+ )
+
+
+def test_cased_column_names(db):
+ tbl = db["cased_column_names"]
+ tbl.insert({"place": "Berlin"})
+ tbl.insert({"Place": "Berlin"})
+ tbl.insert({"PLACE ": "Berlin"})
+ assert len(tbl.columns) == 2, tbl.columns
+ assert len(list(tbl.find(Place="Berlin"))) == 3
+ assert len(list(tbl.find(place="Berlin"))) == 3
+ assert len(list(tbl.find(PLACE="Berlin"))) == 3
+
+
+def test_invalid_column_names(db):
+ tbl = db["weather"]
+ with pytest.raises(ValueError):
+ tbl.insert({None: "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"": "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"-": "banana"})
+
+
+def test_delete(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ original_count = len(table)
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ # Test bad use of API
+ with pytest.raises(ArgumentError):
+ table.delete({"place": "Berlin"})
+ assert len(table) == original_count, len(table)
+
+ assert table.delete(place="Berlin") is True, "should return 1"
+ assert len(table) == len(TEST_DATA), len(table)
+ assert table.delete() is True, "should return non zero"
+ assert len(table) == 0, len(table)
+
+
+def test_repr(table):
+ assert (
+        repr(table) == "<Table(weather)>"
+    ), "the representation should be <Table(weather)>"
+
+
+def test_delete_nonexist_entry(table):
+ assert (
+ table.delete(place="Berlin") is False
+ ), "entry not exist, should fail to delete"
+
+
+def test_find_one(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ d = table.find_one(place="Berlin")
+ assert d["temperature"] == -10, d
+ d = table.find_one(place="Atlantis")
+ assert d is None, d
+
+
+def test_count(table):
+ assert len(table) == 6, len(table)
+ length = table.count(place=TEST_CITY_1)
+ assert length == 3, length
+
+
+def test_find(table):
+ ds = list(table.find(place=TEST_CITY_1))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=1, _step=2))
+ assert len(ds) == 1, ds
+ ds = list(table.find(_step=2))
+ assert len(ds) == len(TEST_DATA), ds
+ ds = list(table.find(order_by=["temperature"]))
+ assert ds[0]["temperature"] == -1, ds
+ ds = list(table.find(order_by=["-temperature"]))
+ assert ds[0]["temperature"] == 8, ds
+ ds = list(table.find(table.table.columns.temperature > 4))
+ assert len(ds) == 3, ds
+
+
+def test_find_dsl(table):
+ ds = list(table.find(place={"like": "%lw%"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={">": 5}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={">=": 5}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={"<": 0}))
+ assert len(ds) == 1, ds
+ ds = list(table.find(temperature={"<=": 0}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={"!=": -1}))
+ assert len(ds) == 5, ds
+ ds = list(table.find(temperature={"between": [5, 8]}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"=": "G€lway"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"ilike": "%LwAy"}))
+ assert len(ds) == 3, ds
+
+
+def test_offset(table):
+ ds = list(table.find(place=TEST_CITY_1, _offset=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2))
+ assert len(ds) == 1, ds
+
+
+def test_streamed_update(table):
+ ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1))
+ assert len(ds) == 3, len(ds)
+ for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1):
+ row["temperature"] = -1
+ table.update(row, ["id"])
+
+
+def test_distinct(table):
+ x = list(table.distinct("place"))
+ assert len(x) == 2, x
+ x = list(table.distinct("place", "date"))
+ assert len(x) == 6, x
+ x = list(
+ table.distinct(
+ "place",
+ "date",
+ table.table.columns.date >= datetime(2011, 1, 2, 0, 0),
+ )
+ )
+ assert len(x) == 4, x
+
+ x = list(table.distinct("temperature", place="B€rkeley"))
+ assert len(x) == 3, x
+ x = list(table.distinct("temperature", place=["B€rkeley", "G€lway"]))
+ assert len(x) == 6, x
+
+
+def test_insert_many(table):
+ data = TEST_DATA * 100
+ table.insert_many(data, chunk_size=13)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert(table):
+ data = TEST_DATA * 100
+ with chunked.ChunkedInsert(table) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert_callback(table):
+ data = TEST_DATA * 100
+ N = 0
+
+ def callback(queue):
+ nonlocal N
+ N += len(queue)
+
+ with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(data) == N
+ assert len(table) == len(data) + 6
+
+
+def test_update_many(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
+ tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
+
+
+def test_chunked_update(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many(
+ [
+ dict(temp=10, location="asdf"),
+ dict(temp=20, location="qwer"),
+ dict(temp=30, location="asdf"),
+ ]
+ )
+ db.commit()
+
+ chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"])
+ chunked_tbl.update(dict(id=1, temp=50))
+ chunked_tbl.update(dict(id=2, location="asdf"))
+ chunked_tbl.update(dict(id=3, temp=50))
+ chunked_tbl.flush()
+ db.commit()
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
+ assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
+
+
+def test_upsert_many(db):
+ # Also tests updating on records with different attributes
+ tbl = db["upsert_many_test"]
+
+ W = 100
+ tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
+ assert tbl.find_one(id=1)["age"] == 10
+
+ tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
+ assert tbl.find_one(id=2)["weight"] == W / 2
+
+
+def test_drop_operations(table):
+ assert table._table is not None, "table shouldn't be dropped yet"
+ table.drop()
+ assert table._table is None, "table should be dropped now"
+ assert list(table.all()) == [], table.all()
+ assert table.count() == 0, table.count()
+
+
+def test_table_drop(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+
+
+def test_table_drop_then_create(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+ db["weather"].insert({"foo": "bar"})
+
+
+def test_columns(table):
+ cols = table.columns
+ assert len(list(cols)) == 4, "column count mismatch"
+ assert "date" in cols and "temperature" in cols and "place" in cols
+
+
+def test_drop_column(table):
+ try:
+ table.drop_column("date")
+ assert "date" not in table.columns
+ except RuntimeError:
+ pass
+
+
+def test_iter(table):
+ c = 0
+ for row in table:
+ c += 1
+ assert c == len(table)
+
+
+def test_update(table):
+ date = datetime(2011, 1, 2)
+ res = table.update(
+ {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"]
+ )
+ assert res, "update should return True"
+ m = table.find_one(place=TEST_CITY_1, date=date)
+ assert m["temperature"] == -10, (
+ "new temp. should be -10 but is %d" % m["temperature"]
+ )
+
+
+def test_create_column(db, table):
+ flt = db.types.float
+ table.create_column("foo", flt)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type
+ assert "foo" in table.columns, table.columns
+
+
+def test_ensure_column(db, table):
+ flt = db.types.float
+ table.create_column_by_example("foo", 0.1)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["bar"].type
+ table.create_column_by_example("bar", 1)
+ assert "bar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type
+ table.create_column_by_example("pippo", "test")
+ assert "pippo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type
+ table.create_column_by_example("bigbar", 11111111111)
+ assert "bigbar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[
+ "bigbar"
+ ].type
+ table.create_column_by_example("littlebar", -11111111111)
+ assert "littlebar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[
+ "littlebar"
+ ].type
+
+
+def test_key_order(db, table):
+ res = db.query("SELECT temperature, place FROM weather LIMIT 1")
+ keys = list(res.next().keys())
+ assert keys[0] == "temperature"
+ assert keys[1] == "place"
+
+
+def test_empty_query(table):
+ empty = list(table.find(place="not in data"))
+ assert len(empty) == 0, empty
+
+def test_order_by_invalid(table):
+ """Test that order_by with an invalid column does not break and that valid ordering still works."""
+ # Insert an additional row so ordering makes sense
+ table.insert({"date": datetime(2022, 1, 1), "temperature": 5, "place": "TestCity"})
+ # Pass an invalid order_by column ("non_existent") before a valid one ("temperature")
+ results = list(table.find(order_by=["non_existent", "temperature"]))
+ # We expect that only a valid ordering is applied eventually.
+ # Here we simply check that rows have the "temperature" key.
+ for row in results:
+ assert "temperature" in row
+
+def test_between_invalid(table):
+ """Test that providing an invalid number of arguments to a 'between' filter raises an error."""
+ with pytest.raises(ValueError):
+ # Passing a list with one element should raise an error because tuple unpacking fails.
+ list(table.find(temperature={"between": [5]}))
+
+def test_find_unsupported_operator(table):
+ """Test that using an unsupported operator in a filter returns no results."""
+ # Insert a known row
+ table.insert({"date": datetime(2022, 2, 2), "temperature": 15, "place": "TestTown"})
+ results = list(table.find(temperature={"unknown": 15}))
+ # An unsupported operator should resolve to a false() clause so that no rows match.
+ assert results == []
+
+def test_update_no_change(table):
+ """Test update when the update payload is fully consumed by key filters resulting in no new data to update."""
+ # Insert a unique row.
+ row_id = table.insert({"date": datetime(2020, 5, 5), "temperature": 20, "place": "Unique"})
+ row = table.find_one(id=row_id)
+ # Calling update with keys only (i.e. without any remaining data) should return the count of matching rows.
+ count = table.update(row, ["id"])
+ # At least one row should match
+ assert count >= 1
+
+def test_insert_non_dict(table):
+ """Test that inserting a non-dictionary value causes an error."""
+ with pytest.raises(AttributeError):
+ table.insert("not a dict")
+
+def test_create_index_duplicate(table):
+ """Test that creating an index on the same column twice does not create duplicate indexes or raise an error."""
+ # First creation of the index
+ table.create_index(["place"])
+ # Second creation should simply be a no-op.
+ table.create_index(["place"])
+ # Fetch indexes and verify that at least one index covers the 'place' column.
+ indexes = table.db.inspect.get_indexes(table.name, schema=table.db.schema)
+ assert any("place" in idx.get("column_names", []) for idx in indexes)
+
+def test_drop_nonexistent_column(table):
+ """Test that attempting to drop a non-existent column does not crash the operation."""
+ if table.db.is_sqlite:
+ pytest.skip("SQLite does not support dropping columns, skipping test.")
+ try:
+ table.drop_column("nonexistent_column")
+ except Exception:
+ pytest.fail("Dropping a non-existent column raised an exception")
+
+
+def test_distinct_after_clear(table):
+    """Test distinct values on a column after clearing the table and inserting duplicates."""
+    # Clear the table for a controlled test.
+    table.delete()
+    # Insert rows with duplicate 'place' values.
+    table.insert({"date": datetime(2022, 3, 3), "temperature": 10, "place": "SameTown"})
+    table.insert({"date": datetime(2022, 3, 4), "temperature": 12, "place": "SameTown"})
+    distinct_rows = list(table.distinct("place"))
+    assert len(distinct_rows) == 1
+
+def test_delete_with_filters(table):
+ """Test deletion with multiple filter criteria to ensure only matching rows are removed."""
+ # Insert distinct rows.
+ table.insert({"date": datetime(2021, 1, 1), "temperature": 0, "place": "DeleteTown"})
+ table.insert({"date": datetime(2021, 1, 2), "temperature": 1, "place": "DeleteTown"})
+ original_count = len(table)
+ # Delete the row with the specified date.
+ deleted = table.delete(place="DeleteTown", date=datetime(2021, 1, 1))
+ assert deleted is True
+ new_count = len(table)
+ assert new_count == original_count - 1
+
+def test_drop_table_idempotent(db):
+ """Test that dropping a table twice does not cause an exception."""
+ table = db["temp_table"]
+ table.insert({"date": datetime(2023, 1, 1), "temperature": 5, "place": "TempPlace"})
+ table.drop()
+ # Dropping an already dropped table should not raise an exception.
+ try:
+ table.drop()
+ except Exception:
+ pytest.fail("Dropping an already dropped table raised an exception")
+
+def test_repr_str_format(table, monkeypatch):
+ """Test that the __repr__ method returns a proper string representation with the table name."""
+ rep = repr(table)
+    assert rep.startswith("<Table")
+ name = table.table.name if table.exists else table.name
+ assert name in rep
+ orig_in_trans = table.db.in_transaction
+
+ monkeypatch.setattr(type(table.db), "in_transaction", property(lambda self: True))
+ monkeypatch.setattr(threading, "active_count", lambda: 2)
+ with pytest.warns(RuntimeWarning, match="Changing the database schema inside a transaction"):
+ table._threading_warn()
+ table._threading_warn()
+
+def test_keys_to_args(table):
+ """Test the _keys_to_args method to ensure keys are separated correctly from the row."""
+ row = {"a": 1, "b": 2}
+ args, remainder = table._keys_to_args(row, "a")
+ assert args == {"a": 1}
+ assert remainder == {"b": 2}
+
+def test_args_to_order_by_invalid(table):
+ """Test _args_to_order_by with an invalid column name and a valid one."""
+ # Insert a column 'valid' to ensure it exists.
+ if not table.has_column("valid"):
+ table.create_column("valid", table.db.types.integer)
+ orderings = table._args_to_order_by(["nonexistent", "valid"])
+ # 'nonexistent' should be ignored, and ordering for 'valid' should appear.
+ assert len(orderings) == 1
+
+def test_sync_columns_auto_create(db):
+ """Test that missing columns are auto-created when ensure_schema is enabled."""
+ tbl = db["auto_create_test"]
+ # Ensure auto_create is enabled so that missing columns get created.
+ tbl._auto_create = True
+ initial_cols = set(tbl.columns)
+ new_data = {"new_col": "test_value", "another_col": 123}
+ tbl.insert(new_data, ensure=True)
+ updated_cols = set(tbl.columns)
+ assert "new_col" in updated_cols
+ assert "another_col" in updated_cols
+
+def test_update_return_count(table):
+ """Test that update returns the count when no new data is provided to update."""
+ # Insert a row to update.
+ row_id = table.insert({"date": datetime(2024, 1, 1), "temperature": 20, "place": "TestUpdate"})
+ row = table.find_one(id=row_id)
+ # Call update with keys that fully consume the row, so no new values remain.
+ count = table.update(row.copy(), ["id"])
+    assert count >= 1
+
+def test_generate_clause_invalid_operator(table):
+ """Test that _generate_clause returns a false() clause when an unsupported operator is provided."""
+ # Call _generate_clause with an invalid operator and verify that it returns a false() SQL expression.
+ clause = table._generate_clause("temperature", "unknown_op", 10)
+ compiled_clause = str(clause.compile(compile_kwargs={"literal_binds": True}))
+ # The compiled clause should contain "false", regardless of uppercase/lowercase.
+ assert "false" in compiled_clause.lower()
+
+def test_args_to_clause_nonexistent_column(table):
+ """Test that _args_to_clause returns a false clause part for a nonexistent column."""
+ clause = table._args_to_clause({"nonexistent": 42})
+ compiled_clause = str(clause.compile(compile_kwargs={"literal_binds": True}))
+ assert "false" in compiled_clause.lower()
+
+def test_delete_using_clause_expression(table):
+ """Test deletion using a clause expression while also combining keyword filters."""
+ # Insert a row that will be deleted using a clause expression.
+ table.insert({"date": datetime(2020, 1, 1), "temperature": 25, "place": "DeleteExpr"})
+ from sqlalchemy import literal
+ clause_expr = (table.table.c.temperature == 25)
+ # Combine the clause expression with an additional keyword filter.
+ deleted = table.delete(clause_expr, place="DeleteExpr")
+ assert deleted is True
+ # Now the row should be gone; deleting it again should return False.
+ deleted_again = table.delete(clause_expr, place="DeleteExpr")
+ assert deleted_again is False
+
+def test_has_index_primary_key(db):
+ """Test that has_index returns True for the primary key column of a table."""
+ tbl = db["index_primary_test"]
+ # Insert a row to force table creation (which creates the primary key column 'id').
+ tbl.insert({"data": "test"})
+ # Calling has_index on the primary key column should return True.
+ assert tbl.has_index("id") is True
+
+def test_args_to_order_by_mixed(table):
+ """Test _args_to_order_by with a mix of valid and invalid column names."""
+ # First ensure that the "valid_order" column exists.
+ if not table.has_column("valid_order"):
+ table.create_column("valid_order", table.db.types.integer)
+ # Insert a dummy row including the new column.
+ table.insert({"date": datetime(2021, 1, 1), "temperature": 0, "place": "TestPlace", "valid_order": 1})
+ orderings = table._args_to_order_by(["nonexistent", "valid_order"])
+ # Only "valid_order" is valid so we expect one ordering.
+ assert len(orderings) == 1
+
+# End of inserted tests.
\ No newline at end of file
diff --git a/tests/test_chunked.py b/tests/test_chunked.py
new file mode 100644
index 0000000..f013d21
--- /dev/null
+++ b/tests/test_chunked.py
@@ -0,0 +1,191 @@
+import pytest
+import itertools
+
+from dataset.chunked import ChunkedInsert, ChunkedUpdate, InvalidCallback
+
+class DummyTable:
+ def __init__(self):
+ self.inserted = []
+ self.updated = []
+
+ def insert_many(self, items):
+ # capture a copy to avoid mutation issues
+ self.inserted.append(list(items))
+
+ def update_many(self, items, keys):
+ self.updated.append((list(items), keys))
+
+# Tests for ChunkedInsert
+
+def test_chunked_insert_invalid_callback():
+ """Test that invalid callback raises InvalidCallback."""
+ dummy = DummyTable()
+ with pytest.raises(InvalidCallback):
+ ChunkedInsert(dummy, callback="not_callable")
+
+def test_chunked_insert_auto_flush():
+ """Test that insert auto flushes when chunk size is reached and callback is executed."""
+ dummy = DummyTable()
+ callback_calls = []
+ def callback(queue):
+ callback_calls.append(list(queue)) # capture snapshot of the queue before insertion
+ inserter = ChunkedInsert(dummy, chunksize=2, callback=callback)
+ # Insert two items; first item is buffered, second triggers flush.
+ row1 = {"a": 1}
+ row2 = {"b": 2}
+ inserter.insert(row1)
+ # At this point, the queue has one item.
+ assert len(inserter.queue) == 1
+ inserter.insert(row2) # triggers flush
+ # After flush, queue should be empty.
+ assert len(inserter.queue) == 0
+ # Verify callback was called exactly once.
+ assert len(callback_calls) == 1
+ # Check that inserted items have fields normalized: each row contains all keys seen so far.
+ inserted = dummy.inserted[0]
+ for item in inserted:
+ assert "a" in item
+ assert "b" in item
+
+def test_chunked_insert_manual_flush():
+ """Test that manual flush works correctly for ChunkedInsert."""
+ dummy = DummyTable()
+ inserter = ChunkedInsert(dummy, chunksize=3)
+ row = {"x": 10}
+ inserter.insert(row)
+ # Queue is not auto flushed because chunksize is not reached.
+ assert len(inserter.queue) == 1
+ inserter.flush() # manually flush
+ assert len(inserter.queue) == 0
+ assert len(dummy.inserted) == 1
+ # After flush, the row should contain the key "x".
+ assert dummy.inserted[0][0]["x"] == 10
+
+def test_chunked_insert_context_manager():
+ """Test that ChunkedInsert works correctly as a context manager."""
+ dummy = DummyTable()
+ with ChunkedInsert(dummy, chunksize=5) as inserter:
+ inserter.insert({"p": 100})
+ inserter.insert({"q": 200})
+ # Not reaching chunksize, so queue should be pending
+ assert len(inserter.queue) == 2
+ # On context exit, flush should have been automatically called.
+ assert len(dummy.inserted) == 1
+ # Verify that fields are normalized across inserted rows.
+ inserted = dummy.inserted[0]
+ for item in inserted:
+ # The union of seen keys is {"p", "q"}.
+ assert "p" in item
+ assert "q" in item
+
+# Tests for ChunkedUpdate
+
+def test_chunked_update_auto_flush():
+ """Test that update auto flushes and groups rows correctly with callback."""
+ dummy = DummyTable()
+ callback_calls = []
+ def callback(queue):
+ callback_calls.append(list(queue))
+ updater = ChunkedUpdate(dummy, keys=["id"], chunksize=2, callback=callback)
+ # Prepare updates with the same keys.
+ row1 = {"id": 1, "val": "a"}
+ row2 = {"id": 1, "val": "b"}
+ updater.update(row1)
+ # Queue not flushed yet.
+ assert len(updater.queue) == 1
+ updater.update(row2) # triggers flush
+ # After flush, queue should be empty.
+ assert len(updater.queue) == 0
+ # Callback should have been called once.
+ assert len(callback_calls) == 1
+ # Verify that update_many was called with grouped rows (both rows in one group since they share the same keys).
+ assert len(dummy.updated) == 1
+ group, keys_passed = dummy.updated[0]
+ assert keys_passed == ["id"]
+ assert len(group) == 2
+
+def test_chunked_update_manual_flush_with_different_keys():
+ """Test manual flush for ChunkedUpdate with rows having different keys."""
+ dummy = DummyTable()
+ updater = ChunkedUpdate(dummy, keys=["key"], chunksize=5)
+ # Insert rows with different sets of keys.
+ row1 = {"key": 1, "data": "x"}
+ row2 = {"data": "y"} # lacks key 'key'
+ updater.update(row1)
+ updater.update(row2)
+ # Manually flush without reaching the chunksize.
+ updater.flush()
+ # Since groupby groups items based on dict.keys, these rows should fall into 2 distinct groups.
+ assert len(dummy.updated) == 2
+
+def test_chunked_update_context_manager():
+ """Test that ChunkedUpdate works correctly as a context manager."""
+ dummy = DummyTable()
+ with ChunkedUpdate(dummy, keys=["uid"], chunksize=3) as updater:
+ updater.update({"uid": 10, "status": "ok"})
+ updater.update({"uid": 10, "status": "fail"})
+ # Upon context exit, flush should have been automatically called.
+ # Expect one group since both rows have the same keys.
+ assert len(dummy.updated) == 1
+ group, keys_passed = dummy.updated[0]
+ assert keys_passed == ["uid"]
+ assert len(group) == 2
+def test_chunked_insert_empty_flush():
+ """Test that flush on empty queue for ChunkedInsert calls the callback and insert_many even if no items exist."""
+ dummy = DummyTable()
+ callback_calls = []
+ # Define a callback that appends a copy of the queue when called
+ def callback(queue):
+ callback_calls.append(list(queue))
+ inserter = ChunkedInsert(dummy, chunksize=2, callback=callback)
+ # Directly call flush even though no items have been inserted.
+ inserter.flush()
+ # flush calls callback (with empty queue) and insert_many with empty list.
+ # DummyTable.insert_many will append an empty list to dummy.inserted.
+ assert callback_calls == [[]]
+ assert dummy.inserted == [[]]
+
+def test_chunked_update_empty_flush():
+ """Test that flush on empty queue for ChunkedUpdate calls the callback and update_many even if no items exist."""
+ dummy = DummyTable()
+ callback_calls = []
+ def callback(queue):
+ callback_calls.append(list(queue))
+ updater = ChunkedUpdate(dummy, keys=["id"], chunksize=2, callback=callback)
+ updater.flush()
+ # Since there are no items, callback is called with an empty list and update_many is never invoked.
+ assert callback_calls == [[]]
+ # No update_many call should result in dummy.updated remaining empty.
+ assert dummy.updated == []
+
+def test_chunked_insert_context_manager_exception():
+ """Test that ChunkedInsert flushes on context exit even when an exception is raised within the context."""
+ dummy = DummyTable()
+ try:
+ with ChunkedInsert(dummy, chunksize=5) as inserter:
+ inserter.insert({"p": 300})
+ raise ValueError("ChunkedInsert exception")
+ except ValueError:
+ pass
+ # Upon context exit, flush should have been called.
+ # DummyTable.insert_many should have been invoked once with the pending row.
+ assert len(dummy.inserted) == 1
+ inserted = dummy.inserted[0]
+ # Normalization should happen (the union of keys is {"p"})
+ for item in inserted:
+ assert "p" in item
+
+def test_chunked_update_context_manager_exception():
+ """Test that ChunkedUpdate flushes on context exit even when an exception is raised within the context."""
+ dummy = DummyTable()
+ try:
+ with ChunkedUpdate(dummy, keys=["uid"], chunksize=5) as updater:
+ updater.update({"uid": 20, "flag": True})
+ raise ValueError("ChunkedUpdate exception")
+ except ValueError:
+ pass
+ # Flush should have been called on exit, causing update_many to be invoked.
+ assert len(dummy.updated) == 1
+ group, keys_passed = dummy.updated[0]
+ assert keys_passed == ["uid"]
+ assert len(group) == 1
\ No newline at end of file
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 0000000..5044d2e
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,281 @@
+import pytest
+from collections import namedtuple, OrderedDict
+from urllib.parse import urlparse, parse_qs
+from dataset.util import (
+ convert_row,
+ iter_result_proxy,
+ make_sqlite_url,
+ ResultIter,
+ normalize_column_name,
+ normalize_column_key,
+ normalize_table_name,
+ safe_url,
+ index_name,
+ pad_chunk_columns,
+ DatasetException,
+)
+from sqlalchemy.exc import ResourceClosedError
+
+class DummyResultProxy:
+ """A dummy ResultProxy to simulate fetchall and fetchmany behavior."""
+ def __init__(self, chunks, use_fetchall=True):
+ self.chunks = chunks.copy()
+ self.use_fetchall = use_fetchall
+
+ def fetchall(self):
+ if self.chunks:
+ return self.chunks.pop(0)
+ return []
+
+ def fetchmany(self, size):
+ if self.chunks:
+ chunk = self.chunks.pop(0)
+ return chunk[:size]
+ return []
+
+ def keys(self):
+ return ['a', 'b']
+
+ def close(self):
+ pass
+
+class DummyCursor:
+ """A dummy cursor simulating sqlite cursor behavior."""
+ def __init__(self, rows):
+ self._rows = iter(rows)
+ self.closed = False
+
+ def keys(self):
+ return ['a', 'b']
+
+ def fetchmany(self, size):
+ result = []
+ try:
+ for _ in range(size):
+ result.append(next(self._rows))
+ except StopIteration:
+ pass
+ return result
+
+ def fetchall(self):
+ return list(self._rows)
+
+ def close(self):
+ self.closed = True
+
+class DummyCursorClosed:
+ """A dummy cursor that immediately raises ResourceClosedError."""
+ def keys(self):
+ raise ResourceClosedError("Cursor closed")
+
+ def close(self):
+ pass
+
+class TestUtil:
+ """Test suite for functions in dataset/util.py"""
+
+ def test_convert_row(self):
+ """Test converting a namedtuple row into an OrderedDict."""
+ Row = namedtuple("Row", ["a", "b"])
+ row = Row(1, 2)
+ result = convert_row(OrderedDict, row)
+ assert isinstance(result, OrderedDict)
+ assert result["a"] == 1
+ assert result["b"] == 2
+
+ def test_convert_row_none(self):
+ """Test that convert_row returns None when the input row is None."""
+ result = convert_row(OrderedDict, None)
+ assert result is None
+
+ def test_iter_result_proxy_fetchall(self):
+ """Test iter_result_proxy using fetchall."""
+ rows = [(1, 2), (3, 4)]
+ rp = DummyResultProxy(chunks=[rows], use_fetchall=True)
+ results = list(iter_result_proxy(rp))
+ assert results == rows
+
+ def test_iter_result_proxy_fetchmany(self):
+ """Test iter_result_proxy using fetchmany with a specified step size."""
+ rows = [(5, 6), (7, 8)]
+ rp = DummyResultProxy(chunks=[rows], use_fetchall=False)
+ results = list(iter_result_proxy(rp, step=1))
+ # With step=1, fetchmany returns one row per call.
+ assert results == [rows[0]]
+
+ def test_make_sqlite_url_no_params(self):
+ """Test make_sqlite_url without extra parameters."""
+ path = "test.db"
+ url = make_sqlite_url(path)
+ assert url == "sqlite:///" + path
+
+ def test_make_sqlite_url_with_params(self):
+ """Test make_sqlite_url when using various extra parameters."""
+ path = "test.db"
+ url = make_sqlite_url(path, cache="shared", timeout=30, mode="rw", check_same_thread=False, immutable=True, nolock=True)
+ parsed = urlparse(url)
+ qs = parse_qs(parsed.query)
+ assert qs.get("cache") == ["shared"]
+ assert qs.get("timeout") == ["30"]
+ assert qs.get("mode") == ["rw"]
+ assert qs.get("nolock") == ["1"]
+ assert qs.get("immutable") == ["1"]
+ assert qs.get("check_same_thread") == ["false"]
+ assert qs.get("uri") == ["true"]
+
+ def test_resultiter_iteration(self):
+ """Test that ResultIter correctly wraps the cursor and converts rows."""
+ rows = [tuple([1, 2]), tuple([3, 4])]
+ Row = namedtuple("Row", ["a", "b"])
+ rows = [Row(1, 2), Row(3, 4)]
+ cursor = DummyCursor(rows)
+ result_iter = ResultIter(cursor, row_type=OrderedDict, step=1)
+ result_list = list(result_iter)
+ expected = [OrderedDict([('a', 1), ('b', 2)]), OrderedDict([('a', 3), ('b', 4)])]
+ assert result_list == expected
+ # Verify that the cursor is closed after iteration.
+ assert cursor.closed is True
+
+ def test_resultiter_closed_cursor(self):
+ """Test that ResultIter handles a cursor that raises ResourceClosedError."""
+ cursor = DummyCursorClosed()
+ result_iter = ResultIter(cursor, row_type=OrderedDict, step=1)
+ result_list = list(result_iter)
+ assert result_list == []
+
+ def test_normalize_column_name_valid(self):
+ """Test normalize_column_name with a valid input string."""
+ name = " valid_column "
+ normalized = normalize_column_name(name)
+ assert normalized == "valid_column"
+
+ def test_normalize_column_name_invalid_type(self):
+ """Test that normalize_column_name raises ValueError when input is not a string."""
+ with pytest.raises(ValueError):
+ normalize_column_name(123)
+
+ def test_normalize_column_name_invalid_chars(self):
+ """Test normalize_column_name raises ValueError when name contains invalid characters."""
+ with pytest.raises(ValueError):
+ normalize_column_name("invalid.column")
+ with pytest.raises(ValueError):
+ normalize_column_name("invalid-column")
+
+ def test_normalize_column_name_truncate(self):
+ """Test that normalize_column_name properly truncates overly long names."""
+ long_name = "a" * 70
+ normalized = normalize_column_name(long_name)
+ # Ensure that the UTF-8 encoded byte length is less than 64.
+ assert len(normalized.encode("utf-8")) < 64
+
+ def test_normalize_column_key(self):
+ """Test normalize_column_key to verify trimming and uppercase conversion."""
+ key = " col Name "
+ normalized = normalize_column_key(key)
+ assert normalized == "COLNAME"
+ assert normalize_column_key(None) is None
+
+ def test_normalize_table_name_valid(self):
+ """Test normalize_table_name with a valid input."""
+ name = " my_table "
+ normalized = normalize_table_name(name)
+ assert normalized == "my_table"
+
+ def test_normalize_table_name_invalid(self):
+ """Test normalize_table_name raises ValueError for invalid table names."""
+ with pytest.raises(ValueError):
+ normalize_table_name("")
+ with pytest.raises(ValueError):
+ normalize_table_name(456)
+
+ def test_safe_url_with_password(self):
+ """Test that safe_url removes the password from the URL."""
+ url = "postgresql://user:secret@localhost:5432/db"
+ safe = safe_url(url)
+ assert "secret" not in safe
+ assert "*****" in safe
+
+ def test_safe_url_without_password(self):
+ """Test that safe_url returns the URL unchanged if no password is present."""
+ url = "postgresql://user@localhost:5432/db"
+ safe = safe_url(url)
+ assert safe == url
+
+ def test_index_name(self):
+ """Test that index_name produces the correct artificial index name."""
+ idx = index_name("table", ["col1", "col2"])
+ assert idx.startswith("ix_table_")
+ parts = idx.split("_")
+ assert len(parts[-1]) == 16
+
+ def test_pad_chunk_columns(self):
+ """Test that pad_chunk_columns properly pads records with missing columns."""
+ chunk = [{"a": 1}, {"b": 2}]
+ columns = ["a", "b", "c"]
+ padded = pad_chunk_columns(chunk, columns)
+ for record in padded:
+ for col in columns:
+ assert col in record
+
+ def test_dataset_exception(self):
+ """Test that DatasetException can be raised and caught."""
+ with pytest.raises(DatasetException):
+ raise DatasetException("Test error")
+ def test_iter_result_proxy_empty(self):
+ """Test iter_result_proxy yields nothing when there are no rows."""
+ rp = DummyResultProxy(chunks=[], use_fetchall=True)
+ results = list(iter_result_proxy(rp))
+ assert results == []
+
+ def test_resultiter_manual_close(self):
+ """Test that manually calling close on a ResultIter properly closes the cursor."""
+ cursor = DummyCursor([])
+ result_iter = ResultIter(cursor, row_type=OrderedDict, step=1)
+ result_iter.close()
+ assert cursor.closed is True
+
+ def test_make_sqlite_url_invalid_cache(self):
+ """Test that make_sqlite_url raises an AssertionError for an invalid cache value."""
+ with pytest.raises(AssertionError):
+ make_sqlite_url("test.db", cache="invalid")
+
+ def test_make_sqlite_url_invalid_mode(self):
+ """Test that make_sqlite_url raises an AssertionError for an invalid mode value."""
+ with pytest.raises(AssertionError):
+ make_sqlite_url("test.db", mode="invalid")
+
+ def test_resultiter_iter_method(self):
+ """Test that __iter__ of ResultIter returns the instance itself."""
+ cursor = DummyCursor([])
+ result_iter = ResultIter(cursor, row_type=OrderedDict, step=1)
+ assert result_iter.__iter__() is result_iter
+
+ def test_convert_row_invalid(self):
+ """Test that convert_row raises an error when the row does not have a _fields attribute."""
+ with pytest.raises(AttributeError):
+ convert_row(OrderedDict, (1, 2))
+
+ def test_normalize_column_key_empty(self):
+ """Test that normalize_column_key returns an empty string when input is only whitespace."""
+ normalized = normalize_column_key(" ")
+ assert normalized == ""
+
+ def test_index_name_consistency(self):
+ """Test that index_name returns the same result when called twice with the same inputs, and different when columns order changes."""
+ name1 = index_name("table", ["col1", "col2"])
+ name2 = index_name("table", ["col1", "col2"])
+ assert name1 == name2
+ name3 = index_name("table", ["col2", "col1"])
+ assert name1 != name3
+
+ def test_pad_chunk_columns_already_full(self):
+ """Test that pad_chunk_columns does not modify records that already contain all columns."""
+ chunk = [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]
+ columns = ["a", "b", "c"]
+ padded = pad_chunk_columns(chunk, columns)
+ assert padded == chunk
+
+ def test_pad_chunk_columns_empty_chunk(self):
+ """Test that pad_chunk_columns returns an empty list when given an empty chunk."""
+ padded = pad_chunk_columns([], ["a", "b"])
+ assert padded == []
\ No newline at end of file