diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 0000000..419e243
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: pytest
+# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
\ No newline at end of file
diff --git a/test/test_database.py b/test/test_database.py
new file mode 100644
index 0000000..2390037
--- /dev/null
+++ b/test/test_database.py
@@ -0,0 +1,221 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+
+from dataset import connect
+
+from .conftest import TEST_DATA
+
+
+def test_valid_database_url(db):
+ assert db.url, os.environ["DATABASE_URL"]
+
+
+def test_database_url_query_string(db):
+ db = connect("sqlite:///:memory:/?cached_statements=1")
+ assert "cached_statements" in db.url, db.url
+
+
+def test_tables(db, table):
+ assert db.tables == ["weather"], db.tables
+
+
+def test_contains(db, table):
+ assert "weather" in db, db.tables
+
+
+def test_create_table(db):
+ table = db["foo"]
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert "id" in table.table.c, table.table.c
+
+
+def test_create_table_no_ids(db):
+ if db.is_mysql or db.is_sqlite:
+ return
+ table = db.create_table("foo_no_id", primary_id=False)
+ assert table.table.name == "foo_no_id"
+ assert len(table.table.columns) == 0, table.table.columns
+
+
+def test_create_table_custom_id1(db):
+ pid = "string_id"
+ table = db.create_table("foo2", pid, db.types.string(255))
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+ table.insert({pid: "foobar"})
+ assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id2(db):
+ pid = "string_id"
+ table = db.create_table("foo3", pid, db.types.string(50))
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({pid: "foobar"})
+ assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id3(db):
+ pid = "int_id"
+ table = db.create_table("foo4", primary_id=pid)
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({pid: 123})
+ table.insert({pid: 124})
+ assert table.find_one(int_id=123)[pid] == 123
+ assert table.find_one(int_id=124)[pid] == 124
+ with pytest.raises(IntegrityError):
+ table.insert({pid: 123})
+ db.rollback()
+
+
+def test_create_table_shorthand1(db):
+ pid = "int_id"
+ table = db.get_table("foo5", pid)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({"int_id": 123})
+ table.insert({"int_id": 124})
+ assert table.find_one(int_id=123)["int_id"] == 123
+ assert table.find_one(int_id=124)["int_id"] == 124
+ with pytest.raises(IntegrityError):
+ table.insert({"int_id": 123})
+
+
+def test_create_table_shorthand2(db):
+ pid = "string_id"
+ table = db.get_table("foo6", primary_id=pid, primary_type=db.types.string(255))
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({"string_id": "foobar"})
+ assert table.find_one(string_id="foobar")["string_id"] == "foobar"
+
+
+def test_with(db, table):
+ init_length = len(table)
+ with pytest.raises(ValueError):
+ with db:
+ table.insert(
+ {
+ "date": datetime(2011, 1, 1),
+ "temperature": 1,
+ "place": "tmp_place",
+ }
+ )
+ raise ValueError()
+ db.rollback()
+ assert len(table) == init_length
+
+
+def test_invalid_values(db, table):
+ if db.is_mysql:
+ # WARNING: mysql seems to be doing some weird type casting
+ # upon insert. The mysql-python driver is not affected but
+ # it isn't compatible with Python 3
+ # Conclusion: use postgresql.
+ return
+ with pytest.raises(SQLAlchemyError):
+ table.insert({"date": True, "temperature": "wrong_value", "place": "tmp_place"})
+
+
+def test_load_table(db, table):
+ tbl = db.load_table("weather")
+ assert tbl.table.name == table.table.name
+
+
+def test_query(db, table):
+ r = db.query("SELECT COUNT(*) AS num FROM weather").next()
+ assert r["num"] == len(TEST_DATA), r
+
+
+def test_table_cache_updates(db):
+ tbl1 = db.get_table("people")
+ data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
+ tbl1.insert(data)
+ data["id"] = 1
+ tbl2 = db.get_table("people")
+ assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
+
+
+ def test_repr(db):
+ """Test __repr__ returns safe URL representation."""
+ rep = repr(db)
+ from dataset.util import safe_url
+ assert safe_url(db.url) in rep
+
+ def test_in_transaction(db):
+ """Test that in_transaction returns True as soon as a connection is accessed."""
+ # accessing the connection starts a transaction so this should be True.
+ assert db.in_transaction is True
+
+ def test_close_database(db):
+ """Test that closing the database resets engine and table cache."""
+ _ = db.get_table("dummy_close")
+ db.close()
+ assert db._engine is None
+ assert db._tables == {}
+
+ def test_ipython_key_completions(db):
+ """Test that _ipython_key_completions returns a list of table names."""
+ _ = db.get_table("ipython_test")
+ completions = db._ipython_key_completions_()
+ # normalize table names are used so 'ipython_test' should be there.
+ assert "ipython_test" in completions
+
+ def test_flush_tables(db):
+ """Test that _flush_tables resets internal table metadata for cached tables."""
+ table_obj = db.get_table("flush_test")
+ # Set a dummy _table value
+ table_obj._table = "dummy"
+ db._flush_tables()
+ assert table_obj._table is None
+
+ def test_query_with_text_object(db):
+ """Test that the query method accepts SQLAlchemy text objects."""
+ table_obj = db.create_table("query_text")
+ table_obj.insert({'id': 1})
+ from sqlalchemy.sql import text
+ result = list(db.query(text("SELECT * FROM query_text")))
+ assert len(result) >= 1
+ assert result[0]['id'] == 1
+
+ def test_query_without_step(db):
+ """Test that the query method returns all rows when _step is disabled."""
+ table_obj = db.create_table("query_no_step")
+ table_obj.insert({'id': 2})
+ result = list(db.query("SELECT * FROM query_no_step", _step=False))
+ assert len(result) >= 1
+ assert result[0]['id'] == 2
+
+ def test_context_manager_no_exception(db):
+ """Test that the context manager commits the transaction if no error occurs."""
+ weather_table = db.get_table("weather")
+ initial_count = len(list(weather_table.all()))
+ with db:
+ weather_table.insert({"date": "2023-10-11", "temperature": 25, "place": "test"})
+ new_count = len(list(weather_table.all()))
+ assert new_count == initial_count + 1
+
+ def test_context_manager_with_exception(db):
+ """Test that the context manager rolls back the transaction if an error is raised."""
+ weather_table = db.get_table("weather")
+ initial_count = len(list(weather_table.all()))
+ try:
+ with db:
+ weather_table.insert({"date": "2023-10-12", "temperature": 30, "place": "rollback_test"})
+ raise RuntimeError("Triggering rollback")
+ except RuntimeError:
+ pass
+ new_count = len(list(weather_table.all()))
+ assert new_count == initial_count
diff --git a/test/test_table.py b/test/test_table.py
new file mode 100644
index 0000000..f5c9d89
--- /dev/null
+++ b/test/test_table.py
@@ -0,0 +1,526 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
+
+from dataset import chunked
+
+from .conftest import TEST_DATA, TEST_CITY_1
+
+
+def test_insert(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ last_id = table.insert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_insert_ignore(table):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_ignore_all_key(table):
+ for i in range(0, 4):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_json(table):
+ last_id = table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "place": "Berlin",
+ "info": {
+ "currency": "EUR",
+ "language": "German",
+ "population": 3292365,
+ },
+ }
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_upsert(table):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_single_column(db):
+ table = db["banana_single_col"]
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+
+
+def test_upsert_all_key(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ for i in range(0, 2):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_id(db):
+ table = db["banana_with_id"]
+ data = dict(id=10, title="I am a banana!")
+ table.upsert(data, ["id"])
+ assert len(table) == 1, len(table)
+
+
+def test_update_while_iter(table):
+ for row in table:
+ row["foo"] = "bar"
+ table.update(row, ["place", "date"])
+ assert len(table) == len(TEST_DATA), len(table)
+
+
+def test_weird_column_names(table):
+ with pytest.raises(ValueError):
+ table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "foo.bar": "Berlin",
+ "qux.bar": "Huhu",
+ }
+ )
+
+
+def test_cased_column_names(db):
+ tbl = db["cased_column_names"]
+ tbl.insert({"place": "Berlin"})
+ tbl.insert({"Place": "Berlin"})
+ tbl.insert({"PLACE ": "Berlin"})
+ assert len(tbl.columns) == 2, tbl.columns
+ assert len(list(tbl.find(Place="Berlin"))) == 3
+ assert len(list(tbl.find(place="Berlin"))) == 3
+ assert len(list(tbl.find(PLACE="Berlin"))) == 3
+
+
+def test_invalid_column_names(db):
+ tbl = db["weather"]
+ with pytest.raises(ValueError):
+ tbl.insert({None: "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"": "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"-": "banana"})
+
+
+def test_delete(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ original_count = len(table)
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ # Test bad use of API
+ with pytest.raises(ArgumentError):
+ table.delete({"place": "Berlin"})
+ assert len(table) == original_count, len(table)
+
+ assert table.delete(place="Berlin") is True, "should return 1"
+ assert len(table) == len(TEST_DATA), len(table)
+ assert table.delete() is True, "should return non zero"
+ assert len(table) == 0, len(table)
+
+
+def test_repr(table):
+ assert (
+ repr(table) == "
"
+ ), "the representation should be "
+
+
+def test_delete_nonexist_entry(table):
+ assert (
+ table.delete(place="Berlin") is False
+ ), "entry not exist, should fail to delete"
+
+
+def test_find_one(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ d = table.find_one(place="Berlin")
+ assert d["temperature"] == -10, d
+ d = table.find_one(place="Atlantis")
+ assert d is None, d
+
+
+def test_count(table):
+ assert len(table) == 6, len(table)
+ length = table.count(place=TEST_CITY_1)
+ assert length == 3, length
+
+
+def test_find(table):
+ ds = list(table.find(place=TEST_CITY_1))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=1, _step=2))
+ assert len(ds) == 1, ds
+ ds = list(table.find(_step=2))
+ assert len(ds) == len(TEST_DATA), ds
+ ds = list(table.find(order_by=["temperature"]))
+ assert ds[0]["temperature"] == -1, ds
+ ds = list(table.find(order_by=["-temperature"]))
+ assert ds[0]["temperature"] == 8, ds
+ ds = list(table.find(table.table.columns.temperature > 4))
+ assert len(ds) == 3, ds
+
+
+def test_find_dsl(table):
+ ds = list(table.find(place={"like": "%lw%"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={">": 5}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={">=": 5}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={"<": 0}))
+ assert len(ds) == 1, ds
+ ds = list(table.find(temperature={"<=": 0}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={"!=": -1}))
+ assert len(ds) == 5, ds
+ ds = list(table.find(temperature={"between": [5, 8]}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"=": "G€lway"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"ilike": "%LwAy"}))
+ assert len(ds) == 3, ds
+
+
+def test_offset(table):
+ ds = list(table.find(place=TEST_CITY_1, _offset=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2))
+ assert len(ds) == 1, ds
+
+
+def test_streamed_update(table):
+ ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1))
+ assert len(ds) == 3, len(ds)
+ for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1):
+ row["temperature"] = -1
+ table.update(row, ["id"])
+
+
+def test_distinct(table):
+ x = list(table.distinct("place"))
+ assert len(x) == 2, x
+ x = list(table.distinct("place", "date"))
+ assert len(x) == 6, x
+ x = list(
+ table.distinct(
+ "place",
+ "date",
+ table.table.columns.date >= datetime(2011, 1, 2, 0, 0),
+ )
+ )
+ assert len(x) == 4, x
+
+ x = list(table.distinct("temperature", place="B€rkeley"))
+ assert len(x) == 3, x
+ x = list(table.distinct("temperature", place=["B€rkeley", "G€lway"]))
+ assert len(x) == 6, x
+
+
+def test_insert_many(table):
+ data = TEST_DATA * 100
+ table.insert_many(data, chunk_size=13)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert(table):
+ data = TEST_DATA * 100
+ with chunked.ChunkedInsert(table) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert_callback(table):
+ data = TEST_DATA * 100
+ N = 0
+
+ def callback(queue):
+ nonlocal N
+ N += len(queue)
+
+ with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(data) == N
+ assert len(table) == len(data) + 6
+
+
+def test_update_many(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
+ tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
+
+
+def test_chunked_update(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many(
+ [
+ dict(temp=10, location="asdf"),
+ dict(temp=20, location="qwer"),
+ dict(temp=30, location="asdf"),
+ ]
+ )
+ db.commit()
+
+ chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"])
+ chunked_tbl.update(dict(id=1, temp=50))
+ chunked_tbl.update(dict(id=2, location="asdf"))
+ chunked_tbl.update(dict(id=3, temp=50))
+ chunked_tbl.flush()
+ db.commit()
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
+ assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
+
+
+def test_upsert_many(db):
+ # Also tests updating on records with different attributes
+ tbl = db["upsert_many_test"]
+
+ W = 100
+ tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
+ assert tbl.find_one(id=1)["age"] == 10
+
+ tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
+ assert tbl.find_one(id=2)["weight"] == W / 2
+
+
+def test_drop_operations(table):
+ assert table._table is not None, "table shouldn't be dropped yet"
+ table.drop()
+ assert table._table is None, "table should be dropped now"
+ assert list(table.all()) == [], table.all()
+ assert table.count() == 0, table.count()
+
+
+def test_table_drop(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+
+
+def test_table_drop_then_create(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+ db["weather"].insert({"foo": "bar"})
+
+
+def test_columns(table):
+ cols = table.columns
+ assert len(list(cols)) == 4, "column count mismatch"
+ assert "date" in cols and "temperature" in cols and "place" in cols
+
+
+def test_drop_column(table):
+ try:
+ table.drop_column("date")
+ assert "date" not in table.columns
+ except RuntimeError:
+ pass
+
+
+def test_iter(table):
+ c = 0
+ for row in table:
+ c += 1
+ assert c == len(table)
+
+
+def test_update(table):
+ date = datetime(2011, 1, 2)
+ res = table.update(
+ {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"]
+ )
+ assert res, "update should return True"
+ m = table.find_one(place=TEST_CITY_1, date=date)
+ assert m["temperature"] == -10, (
+ "new temp. should be -10 but is %d" % m["temperature"]
+ )
+
+
+def test_create_column(db, table):
+ flt = db.types.float
+ table.create_column("foo", flt)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type
+ assert "foo" in table.columns, table.columns
+
+
+def test_ensure_column(db, table):
+ flt = db.types.float
+ table.create_column_by_example("foo", 0.1)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["bar"].type
+ table.create_column_by_example("bar", 1)
+ assert "bar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type
+ table.create_column_by_example("pippo", "test")
+ assert "pippo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type
+ table.create_column_by_example("bigbar", 11111111111)
+ assert "bigbar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[
+ "bigbar"
+ ].type
+ table.create_column_by_example("littlebar", -11111111111)
+ assert "littlebar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[
+ "littlebar"
+ ].type
+
+
+def test_key_order(db, table):
+ res = db.query("SELECT temperature, place FROM weather LIMIT 1")
+ keys = list(res.next().keys())
+ assert keys[0] == "temperature"
+ assert keys[1] == "place"
+
+
+def test_empty_query(table):
+ empty = list(table.find(place="not in data"))
+ assert len(empty) == 0, empty
+
+# New tests to increase test coverage
+def test_find_invalid_operator(table):
+ """Test find with an unsupported operator returns no rows using an unknown operator."""
+ # Passing an operator that is not recognized should result in a false() clause, yielding no results.
+ results = list(table.find(temperature={"unknown": 5}))
+ assert results == [], "Using an unsupported operator should return an empty result set"
+
+def test_create_index_custom_name(db):
+ """Test creating an index with a custom name on a table column."""
+ tbl = db["index_test"]
+ # Create the table by inserting a row
+ tbl.insert({"temperature": 22, "place": "TestCity"})
+ tbl.create_index(["temperature"], name="custom_index")
+ assert tbl.has_index(["temperature"]), "Custom index should be recognized by has_index"
+
+def test_insert_many_empty(table):
+ """Test that insert_many with an empty list leaves the table unchanged."""
+ initial_count = len(table)
+ table.insert_many([], chunk_size=10)
+ assert len(table) == initial_count, "Inserting an empty list should not change the table count"
+
+def test_update_many_empty(table):
+ """Test that update_many with an empty list leaves the table unchanged."""
+ initial_count = len(table)
+ table.update_many([], "id")
+ assert len(table) == initial_count, "Updating an empty list should not change the table count"
+
+def test_insert_no_ensure(table):
+ """Test that extra keys are dropped when ensure is False."""
+ # 'place' is an existing column, while 'not_exist' is not
+ row_data = {"place": "TestCity", "not_exist": "should_be_dropped"}
+ new_id = table.insert(row_data, ensure=False)
+ ret = table.find_one(id=new_id)
+ assert "place" in ret, "Existing column 'place' should be present"
+ assert "not_exist" not in ret, "Non-existing column 'not_exist' should be dropped"
+
+def test_find_operator_filters(table):
+ """Test find DSL operators: startswith, endswith, notlike, and notilike."""
+ from datetime import datetime
+ # Insert a row with a new 'text' column (auto-created when ensure=True)
+ new_id = table.insert({
+ "date": datetime(2021, 1, 1),
+ "temperature": 0,
+ "place": "TestPlace",
+ "text": "HelloWorld"
+ }, ensure=True)
+ # Test 'startswith'
+ rows = list(table.find(text={"startswith": "Hell"}))
+ assert any(r.get("text") == "HelloWorld" for r in rows), "Row with text starting 'Hell' should be found"
+ # Test 'endswith'
+ rows = list(table.find(text={"endswith": "World"}))
+ assert any(r.get("text") == "HelloWorld" for r in rows), "Row with text ending 'World' should be found"
+ # Test 'notlike' (should exclude rows that match the pattern)
+ rows = list(table.find(text={"notlike": "%Bye%"}))
+ assert any(r.get("text") == "HelloWorld" for r in rows), "Row with text not matching '%Bye%' should be found"
+ # Test 'notilike' (case-insensitive match)
+ rows = list(table.find(text={"notilike": "%bye%"}))
+ assert any(r.get("text") == "HelloWorld" for r in rows), "Row with text not matching case-insensitive '%bye%' should be found"
+
+def test_update_empty_data(table):
+ """Test update behavior when the update dict becomes empty (only keys are used for filtering)."""
+ # Grab an existing row from the table based on a known column value
+ sample = table.find_one()
+ # Use keys present in the sample so that the update dict (non-key fields) becomes empty.
+ # The update method should then return the count of matched records rather than false.
+ count = table.update(sample, ["place", "date"])
+ assert count > 0, "Update with empty update dict should return the matched row count"
+
+def test_create_index_duplicate_call(db, table):
+ """Test that creating an index twice does not error and the index is recognized."""
+ from datetime import datetime
+ # Insert a new row to ensure the table exists
+ table.insert({"date": datetime(2022, 1, 1), "temperature": 15, "place": "IndexCity"})
+ # Create an index on the 'temperature' column twice
+ table.create_index("temperature")
+ table.create_index("temperature")
+ assert table.has_index("temperature"), "Index on 'temperature' should exist after duplicate calls"
+
+def test_find_with_in_operator(table):
+ """Test filtering using the 'in' operator for the 'place' column."""
+ from datetime import datetime
+ # Insert rows with distinct place values
+ table.insert({"date": datetime(2023, 1, 1), "temperature": 100, "place": "CityA"})
+ table.insert({"date": datetime(2023, 1, 2), "temperature": 101, "place": "CityB"})
+ table.insert({"date": datetime(2023, 1, 3), "temperature": 102, "place": "CityC"})
+ rows = list(table.find(place={"in": ["CityA", "CityC"]}))
+ places = [r.get("place", "") for r in rows]
+ assert "CityA" in places and "CityC" in places, "Rows with CityA and CityC places should be found using 'in' operator"
+
+def test_update_return_count_for_no_update(table):
+ """Test that update returns a count when there is no non-key field to update."""
+ from datetime import datetime
+ row_data = {"date": datetime(2020, 1, 1), "temperature": 20, "place": "CountCity"}
+ new_id = table.insert(row_data)
+ # Update with keys only so that the update dict becomes empty
+ count = table.update({
+ "date": row_data["date"],
+ "temperature": row_data["temperature"],
+ "place": row_data["place"]
+ }, ["place", "date"])
+ assert count >= 1, "Update should return count of matched rows when update dict is empty"
\ No newline at end of file
diff --git a/tests/test_chunked.py b/tests/test_chunked.py
new file mode 100644
index 0000000..5db0748
--- /dev/null
+++ b/tests/test_chunked.py
@@ -0,0 +1,183 @@
+import pytest
+import itertools
+
+from dataset.chunked import ChunkedInsert, ChunkedUpdate, InvalidCallback
+
+class DummyTable:
+ """A dummy table class to record operations on insert and update operations."""
+ def __init__(self):
+ self.inserts = []
+ self.updates = []
+
+ def insert_many(self, items):
+ self.inserts.extend(items)
+
+ def update_many(self, items, keys):
+ # Record the update call as a tuple: (items, keys)
+ self.updates.append((items, keys))
+
+# Test class for Chunked operations
+class TestChunked:
+ """Test suite for ChunkedInsert and ChunkedUpdate functionality."""
+
+ def test_invalid_callback_insert(self):
+ """Test that passing a non-callable callback to ChunkedInsert raises InvalidCallback."""
+ dummy = DummyTable()
+ with pytest.raises(InvalidCallback):
+ ChunkedInsert(dummy, chunksize=1, callback="I am not callable")
+
+ def test_invalid_callback_update(self):
+ """Test that passing a non-callable callback to ChunkedUpdate raises InvalidCallback."""
+ dummy = DummyTable()
+ with pytest.raises(InvalidCallback):
+ ChunkedUpdate(dummy, keys=['id'], chunksize=1, callback=123)
+
+ def test_chunked_insert_auto_flush(self):
+ """Test that ChunkedInsert flushes automatically when the chunksize is reached."""
+ dummy = DummyTable()
+ # Use a small chunksize to trigger flush automatically
+ inserter = ChunkedInsert(dummy, chunksize=2)
+ # Insert rows with different keys; fields should be merged.
+ inserter.insert({'a': 1})
+ inserter.insert({'b': 2})
+ # At this point, flush should have been called automatically
+ assert len(dummy.inserts) == 2
+ # Check that for each inserted row, the missing field has a None value
+ expected = [{'a': 1, 'b': None}, {'a': None, 'b': 2}]
+ assert dummy.inserts == expected
+
+ def test_chunked_insert_context_manager(self):
+ """Test that ChunkedInsert flushes remaining data when exiting the context manager."""
+ dummy = DummyTable()
+ with ChunkedInsert(dummy, chunksize=3) as inserter:
+ inserter.insert({'x': 10})
+ inserter.insert({'y': 20})
+ # Do not reach chunksize so flush happens on __exit__
+ # Now flush should have been called on exit and both inserted.
+ expected = [{'x': 10, 'y': None}, {'x': None, 'y': 20}]
+ assert dummy.inserts == expected
+
+ def test_chunked_update_grouping(self):
+ """Test that ChunkedUpdate groups rows correctly and calls update_many accordingly."""
+ dummy = DummyTable()
+ updater = ChunkedUpdate(dummy, keys=['id'], chunksize=2)
+ # Insert two items with same keys (so they belong to the same group)
+ updater.update({'id': 1, 'value': 'a'})
+ updater.update({'id': 2, 'value': 'b'})
+ # Flush should be called manually
+ updater.flush()
+ # Since both items have the same set of keys, they will be grouped in one update_many call
+ assert len(dummy.updates) == 1
+ updated_items, update_keys = dummy.updates[0]
+ # The update keys passed in should be ['id']
+ assert update_keys == ['id']
+ # Confirm that the updated group contains the two items
+ assert {'id': 1, 'value': 'a'} in updated_items
+ assert {'id': 2, 'value': 'b'} in updated_items
+
+ def test_chunked_update_context_manager_with_callback(self):
+ """Test that ChunkedUpdate flushes on context manager exit and that a callback is invoked."""
+ dummy = DummyTable()
+ callback_called = []
+
+ def callback(queue):
+ callback_called.append(len(queue))
+
+ with ChunkedUpdate(dummy, keys=['id'], chunksize=3, callback=callback) as updater:
+ updater.update({'id': 1, 'name': 'Alice'})
+ updater.update({'id': 2, 'name': 'Bob'})
+ # Do not add a third row so flush happens on __exit__
+ # Ensure callback was called once during flush
+ assert callback_called == [2]
+ # Check that an update_many call was made (even if with a single group)
+ assert len(dummy.updates) == 1
+
+ def test_empty_flush(self):
+ """Test that calling flush on an empty queue works without error for both insert and update."""
+ dummy = DummyTable()
+ inserter = ChunkedInsert(dummy, chunksize=2)
+ # Call flush on an empty queue via context manager
+ with inserter:
+ pass
+ # Also test update flush on empty queue
+ updater = ChunkedUpdate(dummy, keys=['id'], chunksize=2)
+ with updater:
+ pass
+ # No assertions needed; just ensuring no exceptions are raised.
+
+ def test_chunked_insert_callback_invocation(self):
+ """Test that ChunkedInsert invokes its callback and merges fields correctly during flush."""
+ dummy = DummyTable()
+ callback_called = []
+ def callback(queue):
+ callback_called.append(len(queue))
+ # Use a chunksize of 3 to delay flush until 3 items have been added.
+ inserter = ChunkedInsert(dummy, chunksize=3, callback=callback)
+ inserter.insert({'a': 1})
+ inserter.insert({'b': 2})
+ # At this point, flush hasn't been triggered so callback should not have been called.
+ assert callback_called == []
+ # Inserting the third item triggers flush.
+ inserter.insert({'c': 3})
+ # After flush, callback should have been invoked with a queue of length 3.
+ assert callback_called == [3]
+ # All inserted rows should have merged fields from all items: keys 'a', 'b', and 'c'.
+ expected = [
+ {'a': 1, 'b': None, 'c': None},
+ {'a': None, 'b': 2, 'c': None},
+ {'a': None, 'b': None, 'c': 3}
+ ]
+ assert dummy.inserts == expected
+
+ def test_chunked_update_multiple_groups(self):
+ """Test that ChunkedUpdate groups updated items by their set of keys when they differ."""
+ dummy = DummyTable()
+ updater = ChunkedUpdate(dummy, keys=['id'], chunksize=5)
+ # Two items with the same keys (group 1)
+ updater.update({'id': 1, 'value': 'a'})
+ updater.update({'id': 2, 'value': 'b'})
+ # One item with a different keys set (group 2)
+ updater.update({'id': 3})
+ # One more item matching group 1 keys
+ updater.update({'id': 4, 'value': 'd'})
+ # Manually flush to trigger update_many calls
+ updater.flush()
+ # Ensure that exactly two update_many calls were made, one for each distinct keys set.
+ assert len(dummy.updates) == 2
+ # Verify that one group contains three items (group1 with keys {'id', 'value'})
+ # and the other group contains one item (group2 with keys {'id'}).
+ group_sizes = sorted(len(items) for items, _ in dummy.updates)
+ assert group_sizes == [1, 3]
+ # Also verify that all update_many calls used the update keys ['id'].
+ for items, keys in dummy.updates:
+ assert keys == ['id']
+ def test_manual_flush_insert(self):
+ """Test that manually calling flush in ChunkedInsert works as expected (without reaching chunksize)."""
+ dummy = DummyTable()
+ inserter = ChunkedInsert(dummy, chunksize=5)
+ inserter.insert({'p': 100})
+ inserter.insert({'q': 200})
+ # Manually call flush even though the chunksize has not been reached
+ inserter.flush()
+ expected = [{'p': 100, 'q': None}, {'p': None, 'q': 200}]
+ assert dummy.inserts == expected
+
+ def test_manual_flush_update(self):
+ """Test that manually calling flush in ChunkedUpdate triggers the callback and update_many correctly."""
+ dummy = DummyTable()
+ callback_called = []
+
+ def update_callback(queue):
+ callback_called.append(len(queue))
+
+ updater = ChunkedUpdate(dummy, keys=['id'], chunksize=5, callback=update_callback)
+ updater.update({'id': 10, 'r': 'x'})
+ # Manually flush the pending update operation
+ updater.flush()
+ # Verify that the callback was invoked with a queue length of 1
+ assert callback_called == [1]
+ # Verify that update_many was called exactly once and with the correct parameters
+ assert len(dummy.updates) == 1
+ updated_items, update_keys = dummy.updates[0]
+ assert update_keys == ['id']
+ assert {'id': 10, 'r': 'x'} in updated_items
\ No newline at end of file
diff --git a/tests/test_test_row_type.py b/tests/test_test_row_type.py
new file mode 100644
index 0000000..d8002d0
--- /dev/null
+++ b/tests/test_test_row_type.py
@@ -0,0 +1,544 @@
+import pytest
+from datetime import datetime
+from test.test_row_type import Constructor
+
+def test_constructor_attribute_get():
+ """Test that attribute access on Constructor returns expected values."""
+ data = {"temperature": 20, "place": "Paris", "date": datetime(2020, 1, 1)}
+ c = Constructor(data)
+ assert c.temperature == 20
+ assert c.place == "Paris"
+ assert c["date"] == datetime(2020, 1, 1)
+
+def test_constructor_missing_attribute_raises_keyerror():
+ """Test that accessing a missing attribute raises KeyError."""
+ c = Constructor({"a": 1})
+ with pytest.raises(KeyError):
+ _ = c.nonexistent
+
+def test_constructor_update_affects_attribute_access():
+ """Test that updating the Constructor dictionary reflects in attribute access."""
+ c = Constructor({"a": 5})
+ c.update({"b": 10})
+ assert c.b == 10
+
+def test_constructor_overlapping_method_names():
+ """Test that keys overlapping with dict methods still return correct value via __getattr__."""
+ c = Constructor({"keys": "value_of_keys", "get": "value_of_get"})
+ # Direct attribute access for 'keys' returns the built-in method, so we use __getattr__ for testing.
+ assert c["keys"] == "value_of_keys"
+ assert c.__getattr__("get") == "value_of_get"
+# New tests to increase test coverage for table operations and custom row type conversion.
+class FakeDB:
+ """Fake database class for testing that holds a row_type attribute."""
+ def __init__(self):
+ self.row_type = dict
+
+class FakeTable:
+ """Fake table implementing insert, find_one, find, distinct, __iter__, and __len__ operations."""
+ def __init__(self, db):
+ self.records = []
+ self.db = db
+
+ def insert(self, record):
+ self.records.append(record)
+
+ def __apply_row_type(self, rec):
+ if rec is None:
+ return None
+ if hasattr(self.db, "row_type") and self.db.row_type:
+ return self.db.row_type(rec)
+ return rec
+
+ def find_one(self, **kwargs):
+ for rec in self.records:
+ match = True
+ for k, v in kwargs.items():
+ if k == "_limit":
+ continue
+ if rec.get(k) != v:
+ match = False
+ break
+ if match:
+ return self.__apply_row_type(rec)
+ return None
+
+ def find(self, **kwargs):
+ limit = kwargs.pop('_limit', None)
+ result = []
+ for rec in self.records:
+ if all(rec.get(k) == v for k, v in kwargs.items()):
+ result.append(self.__apply_row_type(rec))
+ if limit is not None and len(result) >= limit:
+ break
+ return result
+
+ def distinct(self, *fields):
+ seen = set()
+ result = []
+ for rec in self.records:
+ key = tuple(rec.get(f) for f in fields)
+ if key not in seen:
+ seen.add(key)
+ rec_data = {f: rec.get(f) for f in fields}
+ result.append(self.__apply_row_type(rec_data))
+ return result
+
+ def __iter__(self):
+ for rec in self.records:
+ yield self.__apply_row_type(rec)
+
+ def __len__(self):
+ return len(self.records)
+
+def test_integration_find_one_fake_table():
+ """Test integration of find_one using FakeDB and FakeTable with Constructor as row type."""
+ db = FakeDB()
+ db.row_type = Constructor # use Constructor to allow attribute access
+ table = FakeTable(db)
+ from datetime import datetime
+ row = {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
+ table.insert(row)
+ found = table.find_one(place="Berlin")
+ assert found is not None
+ # Test both dict and attribute access
+ assert found["temperature"] == -10
+ assert found.temperature == -10
+ missing = table.find_one(place="Atlantis")
+ assert missing is None
+
+def test_integration_find_fake_table():
+ """Test integration of the find method with limit functionality using FakeDB and FakeTable."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ # Insert multiple records
+ rows = [
+ {"place": "CityA", "value": 1},
+ {"place": "CityA", "value": 2},
+ {"place": "CityA", "value": 3},
+ {"place": "CityB", "value": 4},
+ ]
+ for r in rows:
+ table.insert(r)
+ results = table.find(place="CityA")
+ assert len(results) == 3
+ for item in results:
+ assert isinstance(item, Constructor)
+ limited = table.find(place="CityA", _limit=2)
+ assert len(limited) == 2
+
+def test_integration_distinct_fake_table():
+ """Test integration of the distinct method using FakeDB and FakeTable with multiple field criteria."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ # Insert records with duplicates
+ rows = [
+ {"cat": "A", "num": 1},
+ {"cat": "A", "num": 2},
+ {"cat": "B", "num": 3},
+ {"cat": "B", "num": 3}, # duplicate
+ {"cat": "C", "num": 4},
+ ]
+ for r in rows:
+ table.insert(r)
+ distinct_cats = table.distinct("cat")
+ # Expect three distinct cat values: A, B, and C
+ assert len(distinct_cats) == 3
+ for item in distinct_cats:
+ assert isinstance(item, Constructor)
+ distinct_cat_num = table.distinct("cat", "num")
+ # Expect four distinct records (duplicates for B,3 should be filtered)
+ assert len(distinct_cat_num) == 4
+
+def test_integration_iter_fake_table():
+ """Test the iteration functionality of FakeTable ensuring all records are converted to Constructor."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1},
+ {"id": 2},
+ {"id": 3},
+ ]
+ for r in rows:
+ table.insert(r)
+ count = 0
+ for rec in table:
+ count += 1
+ assert isinstance(rec, Constructor)
+def test_no_row_type_returns_dict():
+ """Test that when db.row_type is None, FakeTable returns plain dicts instead of using Constructor."""
+ db = FakeDB()
+ db.row_type = None
+ table = FakeTable(db)
+ row = {"id": 1, "value": "test"}
+ table.insert(row)
+ result = table.find_one(id=1)
+ assert result is not None
+ assert isinstance(result, dict)
+ # Ensure that it is not converted to a Constructor
+ assert not hasattr(result, "value") or result["value"] == "test"
+def test_find_no_filters_returns_all():
+ """Test that calling find with no filter kwargs returns all records."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1},
+ {"id": 2},
+ {"id": 3},
+ ]
+ for row in rows:
+ table.insert(row)
+ results = table.find()
+ assert len(results) == len(rows)
+ for res in results:
+ assert isinstance(res, Constructor)
+def test_distinct_missing_field():
+ """Test distinct method handling of keys that are missing in some records."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "value": "a"},
+ {"id": 2}, # missing 'value', so rec.get('value') will be None
+ {"id": 3, "value": "a"},
+ {"id": 4, "value": None},
+ ]
+ for row in rows:
+ table.insert(row)
+ results = table.distinct("value")
+ distinct_values = set()
+ for res in results:
+ distinct_values.add(res.value)
+ # Expect one distinct value "a" and one None
+ assert "a" in distinct_values
+ assert None in distinct_values
+ assert len(results) == 2
+def test_iter_empty_table():
+ """Test that iterating over an empty FakeTable yields no records."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ count = 0
+ for rec in table:
+ count += 1
+ assert count == 0
+ assert len(table) == 0
+def test_find_limit_zero_fake_table():
+ """Test the find method with _limit=0 edge case.
+ Due to the logic ordering in FakeTable.find, a _limit of 0 still returns 1 record."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "value": "x"},
+ {"id": 2, "value": "x"},
+ ]
+ for r in rows:
+ table.insert(r)
+ results = table.find(value="x", _limit=0)
+ # The first matching record is appended then _limit is checked causing an early break
+ assert len(results) == 1
+
+def test_custom_row_type_enhanced_fake_table():
+ """Test using a custom EnhancedConstructor as the row type that adds a description method."""
+ # Define a custom row type that extends Constructor with an extra method.
+ class EnhancedConstructor(Constructor):
+ def description(self):
+ return f"{self.place} on {self.date.strftime('%Y-%m-%d')}"
+
+ db = FakeDB()
+ db.row_type = EnhancedConstructor
+ table = FakeTable(db)
+ from datetime import datetime
+ row = {"date": datetime(2021, 12, 25), "place": "New York", "temperature": 5}
+ table.insert(row)
+ result = table.find_one(place="New York")
+ assert result is not None
+ # Test that the custom method works alongside dictionary and attribute access.
+ assert result.description() == "New York on 2021-12-25"
+ assert result.temperature == 5
+
+def test_find_with_nonexistent_field_returns_empty_list():
+ """Test that searching with a filter for a nonexistent field returns an empty list."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "value": "a"},
+ {"id": 2, "value": "b"},
+ ]
+ for r in rows:
+ table.insert(r)
+ results = table.find(nonexistent="anything")
+ assert results == []
+
+def test_updated_record_affects_iteration():
+ """Test that updating a record in the underlying table affects the result during iteration.
+ This confirms that the conversion via __apply_row_type happens on the fly."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row = {"id": 1, "value": "initial"}
+ table.insert(row)
+ # Update the record directly in table.records
+ table.records[0]["value"] = "updated"
+ for rec in table:
+ assert rec.value == "updated"
+def test_find_with_multiple_conditions():
+ """Test that find_one and find work with multiple filtering conditions."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "a": "foo", "b": "bar", "c": 10},
+ {"id": 2, "a": "foo", "b": "baz", "c": 20},
+ {"id": 3, "a": "qux", "b": "bar", "c": 30},
+ {"id": 4, "a": "foo", "b": "bar", "c": 40},
+ ]
+ for row in rows:
+ table.insert(row)
+ # Test find_one with multiple conditions (should match the first occurrence)
+ res = table.find_one(a="foo", b="bar")
+ assert res is not None
+ assert res.a == "foo"
+ assert res.b == "bar"
+ # Test find with the same conditions returns all matching records (id 1 and 4)
+ results = table.find(a="foo", b="bar")
+ assert len(results) == 2
+
+def test_find_invalid_limit():
+ """Test the find method with a negative _limit value to verify edge case behavior."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "value": "x"},
+ {"id": 2, "value": "x"},
+ {"id": 3, "value": "x"},
+ ]
+ for r in rows:
+ table.insert(r)
+ results = table.find(value="x", _limit=-1)
+ # With a negative _limit the condition is met almost immediately so only one record is returned
+ assert len(results) == 1
+
+def test_custom_lambda_row_type():
+ """Test using a custom lambda as the row type to transform records."""
+ db = FakeDB()
+ # Custom lambda that adds an 'extra' key to the record
+ db.row_type = lambda rec: {**rec, "extra": "yes"}
+ table = FakeTable(db)
+ row = {"id": 100, "value": "lambda test"}
+ table.insert(row)
+ # Test find_one conversion
+ res_one = table.find_one(id=100)
+ assert res_one is not None
+ assert res_one["extra"] == "yes"
+ # Test find conversion
+ results = table.find(id=100)
+ for rec in results:
+ assert rec["extra"] == "yes"
+ # Test distinct conversion using custom lambda row type
+ distinct = table.distinct("value")
+ for rec in distinct:
+ assert rec["extra"] == "yes"
+ # Test iteration conversion
+ found = False
+ for rec in table:
+ if rec.get("id") == 100:
+ found = True
+ assert rec["extra"] == "yes"
+ assert found
+
+def test_non_callable_row_type():
+ """Test that setting a non-callable row_type raises an error when transforming a record."""
+ db = FakeDB()
+ db.row_type = 5 # non-callable row_type
+ table = FakeTable(db)
+ row = {"id": 1, "value": "error expected"}
+ table.insert(row)
+ import pytest
+ with pytest.raises(TypeError):
+ _ = table.find_one(id=1)
+def test_custom_namedtuple_row_type():
+ """Test using a custom row_type as collections.namedtuple returning a tuple conversion."""
+ from collections import namedtuple
+ CustomRow = namedtuple("CustomRow", ["id", "value"])
+ db = FakeDB()
+ db.row_type = lambda rec: CustomRow(**rec)
+ table = FakeTable(db)
+ row = {"id": 1, "value": "test"}
+ table.insert(row)
+ result = table.find_one(id=1)
+ assert isinstance(result, CustomRow)
+ assert result.id == 1
+ assert result.value == "test"
+
+def test_exception_in_custom_row_type():
+ """Test that exceptions in custom row_type conversion propagate as expected."""
+ def faulty_row_type(record):
+ # Attempt to access a key that might be missing so a KeyError is raised.
+ return record["mandatory"]
+ db = FakeDB()
+ db.row_type = faulty_row_type
+ table = FakeTable(db)
+ row = {"id": 1, "value": "error expected"}
+ table.insert(row)
+ import pytest
+ with pytest.raises(KeyError):
+ _ = table.find_one(id=1)
+def test_distinct_empty_fields():
+ """Test that calling distinct with no field filters returns a single entry (an empty dict) as distinct."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"a": 1},
+ {"b": 2},
+ ]
+ for row in rows:
+ table.insert(row)
+ # When no field is provided, every record yields an empty key tuple so only one item is returned.
+ result = table.distinct()
+ assert len(result) == 1
+ # The distinct entry should be an empty dict converted by Constructor.
+ assert result[0] == {}
+
+def test_find_one_with_limit_keyword_ignored():
+ """Test that passing '_limit' to find_one is ignored so it behaves normally."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row = {"id": 10, "value": "test"}
+ table.insert(row)
+ # Pass _limit as extra keyword; it should be skipped during filtering.
+ result = table.find_one(id=10, _limit=100)
+ assert result is not None
+ assert result.id == 10
+
+def test_insert_non_dict_record_exception():
+ """Test that inserting a non-dictionary record will raise an error when trying to perform get on it."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ table.insert(["not", "a", "dict"])
+ with pytest.raises(AttributeError):
+ _ = table.find_one(somekey="value")
+
+def test_find_with_none_as_filter_value():
+ """Test that find correctly matches records with a None filter value."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "optional": None},
+ {"id": 2, "optional": "present"},
+ ]
+ for row in rows:
+ table.insert(row)
+ result = table.find(optional=None)
+ assert len(result) == 1
+ assert result[0].id == 1
+def test_switch_row_type():
+ """Test that switching the db.row_type dynamically affects row conversion."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row = {"id": 99, "name": "First"}
+ table.insert(row)
+ result1 = table.find_one(id=99)
+ # Check that result1 was converted using Constructor
+ assert hasattr(result1, "name")
+ # Now switch the row type to a lambda that adds an extra key
+ db.row_type = lambda rec: {**rec, "extra": "modified"}
+ result2 = table.find_one(id=99)
+ # Check that the new conversion is applied (a plain dict with extra key)
+ assert isinstance(result2, dict)
+ assert result2.get("extra") == "modified"
+
+def test_find_invalid_limit_type():
+ """Test that using a non-integer _limit value causes a TypeError during comparison."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ rows = [
+ {"id": 1, "value": "x"},
+ {"id": 2, "value": "x"},
+ ]
+ for r in rows:
+ table.insert(r)
+ import pytest
+ with pytest.raises(TypeError):
+ _ = table.find(value="x", _limit="non_integer")
+def test_modification_of_converted_record_does_not_affect_underlying_data():
+ """Test that modifying a converted record does not change the underlying record in FakeTable."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row = {"id": 1, "value": "original"}
+ table.insert(row)
+ converted = table.find_one(id=1)
+ # Modify the converted record
+ converted["value"] = "modified"
+ # Get a new conversion from the underlying record
+ fresh = table.find_one(id=1)
+ assert fresh["value"] == "original", "Underlying record should not be affected by modifications to the converted record"
+
+def test_find_returns_fresh_instances_on_each_call():
+ """Test that each call to find_one returns a fresh conversion instance and not a cached object."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row = {"id": 2, "value": "constant"}
+ table.insert(row)
+ first_instance = table.find_one(id=2)
+ second_instance = table.find_one(id=2)
+ assert first_instance is not second_instance, "Each conversion should produce a new instance"
+
+def test_find_no_matching_record_returns_none():
+ """Test that find_one returns None when no record matches the provided filter."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ table.insert({"id": 3, "value": "test"})
+ result = table.find_one(id=999)
+ assert result is None, "find_one should return None if no matching record is found"
+def test_custom_constant_row_type():
+ """Test using a constant lambda as the row_type that returns a constant value regardless of record content.
+ This ensures that find_one, iteration, and distinct all return the constant conversion."""
+ db = FakeDB()
+ db.row_type = lambda rec: "constant"
+ table = FakeTable(db)
+ row = {"id": 101, "value": "dummy"}
+ table.insert(row)
+ result = table.find_one(id=101)
+ assert result == "constant", "Expected constant conversion for find_one"
+ iter_results = list(table)
+ # All results via iteration should yield the constant value
+ for item in iter_results:
+ assert item == "constant", "Iteration did not yield constant conversion"
+ distinct_results = table.distinct("value")
+ for item in distinct_results:
+ assert item == "constant", "Distinct did not yield constant conversion"
+
+def test_find_with_list_field():
+ """Test that the find method correctly matches records with list fields using equality.
+ The record should be found only when the list exactly matches."""
+ db = FakeDB()
+ db.row_type = Constructor
+ table = FakeTable(db)
+ row1 = {"id": 201, "tags": ["a", "b", "c"]}
+ row2 = {"id": 202, "tags": ["x", "y"]}
+ table.insert(row1)
+ table.insert(row2)
+ # Search for the record with an exact match on the list field.
+ result = table.find_one(tags=["a", "b", "c"])
+ assert result is not None, "Expected to find record with matching list"
+ assert result.id == 201
+ # A search for a list that does not match should return None.
+ result_none = table.find_one(tags=["non", "matching"])
+ assert result_none is None, "Expected no record with non-matching list"
\ No newline at end of file
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 0000000..31fc981
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,302 @@
+import pytest
+from collections import namedtuple, OrderedDict
+from urllib.parse import urlparse, parse_qs
+from hashlib import sha1
+from sqlalchemy.exc import ResourceClosedError
+from dataset.util import (
+ convert_row,
+ DatasetException,
+ iter_result_proxy,
+ make_sqlite_url,
+ ResultIter,
+ normalize_column_name,
+ normalize_column_key,
+ normalize_table_name,
+ safe_url,
+ index_name,
+ pad_chunk_columns,
+)
+
+# Fake ResultProxy for testing iter_result_proxy and ResultIter
+class FakeCursor:
+ def __init__(self, rows, keys=None, fetchmany_size=None, raise_on_keys=False):
+ self.rows = rows[:] # copy list of rows
+ self.fetchmany_size = fetchmany_size
+ self.closed = False
+ self._keys = keys if keys is not None else (list(rows[0]._fields) if rows else [])
+ self.raise_on_keys = raise_on_keys
+
+ def fetchall(self):
+ result = self.rows
+ self.rows = []
+ return result
+
+ def fetchmany(self, size):
+ if self.fetchmany_size is not None:
+ size = self.fetchmany_size
+ chunk = self.rows[:size]
+ self.rows = self.rows[size:]
+ return chunk
+
+ def keys(self):
+ if self.raise_on_keys:
+ raise ResourceClosedError("Cursor is closed")
+ return self._keys
+
+ def close(self):
+ self.closed = True
+
+# Define a dummy row type to simulate namedtuple rows for testing convert_row and ResultIter.
+DummyRow = namedtuple("DummyRow", ["a", "b"])
+
+class TestUtil:
+ def test_convert_row_with_valid_row(self):
+ """Test convert_row with a valid row-like object."""
+ row = DummyRow(a=1, b=2)
+ result = convert_row(OrderedDict, row)
+ assert result == OrderedDict([("a", 1), ("b", 2)])
+
+ def test_convert_row_with_none(self):
+ """Test that convert_row returns None when given None."""
+ assert convert_row(OrderedDict, None) is None
+
+ def test_iter_result_proxy_fetchall(self):
+ """Test iter_result_proxy using fetchall method."""
+ rows = [DummyRow(a=i, b=i+1) for i in range(5)]
+ fake_cursor = FakeCursor(rows)
+ result = list(iter_result_proxy(fake_cursor))
+ assert result == rows
+
+ def test_iter_result_proxy_fetchmany(self):
+ """Test iter_result_proxy using fetchmany method."""
+ rows = [DummyRow(a=i, b=i+1) for i in range(5)]
+ fake_cursor = FakeCursor(rows, fetchmany_size=2)
+ result = list(iter_result_proxy(fake_cursor, step=2))
+ assert result == rows
+
+ def test_make_sqlite_url_basic(self):
+ """Test that make_sqlite_url returns a basic URL when no extra params are provided."""
+ url = make_sqlite_url("test.db")
+ assert url == "sqlite:///test.db"
+
+ def test_make_sqlite_url_with_params(self):
+ """Test that make_sqlite_url returns a URL with encoded query parameters."""
+ url = make_sqlite_url("test.db", cache="shared", timeout=30, mode="rw", check_same_thread=False, immutable=True, nolock=True)
+ # The URL should start with 'sqlite:///file:test.db?' due to URI scheme
+ assert url.startswith("sqlite:///file:test.db?")
+ parsed = urlparse(url)
+ qs = parse_qs(parsed.query)
+ expected = {
+ "cache": ["shared"],
+ "timeout": ["30"],
+ "mode": ["rw"],
+ "nolock": ["1"],
+ "immutable": ["1"],
+ "check_same_thread": ["false"],
+ "uri": ["true"]
+ }
+ assert qs == expected
+
+ def test_result_iter_normal_iteration(self):
+ """Test that ResultIter iterates over rows and closes the cursor upon completion."""
+ rows = [DummyRow(a=i, b=i+1) for i in range(3)]
+ fake_cursor = FakeCursor(rows)
+ result_iter = ResultIter(fake_cursor)
+ collected = list(result_iter)
+ expected = [OrderedDict([("a", row.a), ("b", row.b)]) for row in rows]
+ assert collected == expected
+ # After iteration, the cursor should be closed
+ assert fake_cursor.closed is True
+
+ def test_result_iter_closed_cursor(self):
+ """Test that ResultIter handles ResourceClosedError by using an empty iterator."""
+ fake_cursor = FakeCursor([], raise_on_keys=True)
+ result_iter = ResultIter(fake_cursor)
+ result = list(result_iter)
+ assert result == []
+ fake_cursor.close()
+ assert fake_cursor.closed is True
+
+ def test_normalize_column_name_valid(self):
+ """Test normalize_column_name with a valid column name."""
+ valid_name = "valid_column"
+ assert normalize_column_name(valid_name) == valid_name
+
+ def test_normalize_column_name_invalid(self):
+ """Test normalize_column_name with column names that should be rejected."""
+ with pytest.raises(ValueError):
+ normalize_column_name("invalid.column")
+ with pytest.raises(ValueError):
+ normalize_column_name("invalid-column")
+
+ def test_normalize_column_name_length(self):
+ """Test that normalize_column_name truncates column names exceeding 63 bytes."""
+ long_name = "a" * 100
+ result = normalize_column_name(long_name)
+ assert len(result.encode("utf-8")) < 64
+
+ def test_normalize_column_key(self):
+ """Test normalize_column_key converts strings to uppercase and removes spaces."""
+ result = normalize_column_key(" col Name ")
+ assert result == "COLNAME"
+ # Check that None is returned as None
+ assert normalize_column_key(None) is None
+
+ def test_normalize_table_name_valid(self):
+ """Test normalize_table_name with a valid table name."""
+ table_name = "table_name"
+ assert normalize_table_name(table_name) == table_name
+
+ def test_normalize_table_name_invalid(self):
+ """Test normalize_table_name raises ValueError for invalid names."""
+ with pytest.raises(ValueError):
+ normalize_table_name("")
+
+ def test_safe_url(self):
+ """Test that safe_url masks passwords contained in a connection URL."""
+ original_url = "postgres://user:secret@localhost/db"
+ result = safe_url(original_url)
+ assert "secret" not in result
+ assert "*****" in result
+
+ def test_index_name(self):
+ """Test that index_name generates an index name in the expected format."""
+ table = "mytable"
+ columns = ["col1", "col2"]
+ idx = index_name(table, columns)
+ assert idx.startswith("ix_mytable_")
+ key_part = idx[len("ix_mytable_"):]
+ # The key should be exactly 16 hex characters; assert its length and that it is valid hex.
+ assert len(key_part) == 16
+ int(key_part, 16)
+
+ def test_pad_chunk_columns(self):
+ """Test that pad_chunk_columns adds missing columns with None values."""
+ chunk = [{"a": 1}, {"b": 2}]
+ columns = ["a", "b", "c"]
+ result = pad_chunk_columns(chunk, columns)
+ for record in result:
+ for col in columns:
+ assert col in record
+ # Ensure that existing keys are retained
+ assert result[0]["a"] == 1
+
+ def test_dataset_exception(self):
+ """Test that DatasetException carries the correct message."""
+ exc = DatasetException("Test error")
+ assert str(exc) == "Test error"
+
+ def test_make_sqlite_url_invalid_cache(self):
+ """Test that make_sqlite_url raises AssertionError for an invalid cache value."""
+ with pytest.raises(AssertionError):
+ make_sqlite_url("test.db", cache="invalid")
+
+ def test_make_sqlite_url_invalid_mode(self):
+ """Test that make_sqlite_url raises AssertionError for an invalid mode value."""
+ with pytest.raises(AssertionError):
+ make_sqlite_url("test.db", mode="invalid")
+
+ def test_normalize_column_name_non_string(self):
+ """Test that normalize_column_name raises ValueError when given a non-string."""
+ with pytest.raises(ValueError):
+ normalize_column_name(123)
+
+ def test_normalize_table_name_non_string(self):
+ """Test that normalize_table_name raises ValueError when given a non-string."""
+ with pytest.raises(ValueError):
+ normalize_table_name(123)
+
+ def test_normalize_column_key_non_string(self):
+ """Test that normalize_column_key returns None when given a non-string."""
+ assert normalize_column_key(123) is None
+
+ def test_iter_result_proxy_with_zero(self):
+ """Test iter_result_proxy with step=0 returns an empty iterator even if there are rows."""
+ rows = [DummyRow(a=i, b=i+1) for i in range(3)]
+ fake_cursor = FakeCursor(rows, fetchmany_size=0)
+ result = list(iter_result_proxy(fake_cursor, step=0))
+ assert result == []
+ def test_safe_url_no_password(self):
+ """Test that safe_url returns the URL unchanged when there is no password."""
+ url = "postgres://user@localhost/db"
+ result = safe_url(url)
+ assert result == url
+
+ def test_convert_row_custom_row_type(self):
+ """Test convert_row with a custom row_type (dict) instead of OrderedDict."""
+ row = DummyRow(a=10, b=20)
+ result = convert_row(dict, row)
+ expected = dict([("a", 10), ("b", 20)])
+ assert result == expected
+
+ def test_pad_chunk_columns_empty_chunk(self):
+ """Test that pad_chunk_columns returns an empty list when given an empty chunk."""
+ chunk = []
+ columns = ["a", "b"]
+ result = pad_chunk_columns(chunk, columns)
+ assert result == []
+
+ def test_pad_chunk_columns_all_present(self):
+ """Test that pad_chunk_columns does not modify the chunk if all columns are present."""
+ chunk = [{"a": 1, "b": 2}]
+ columns = ["a", "b"]
+ result = pad_chunk_columns(chunk, columns)
+ assert result == chunk
+
+ def test_result_iter_next_after_exhaustion(self):
+ """Test that ResultIter raises StopIteration after all rows have been iterated."""
+ rows = [DummyRow(a=1, b=2)]
+ fake_cursor = FakeCursor(rows)
+ result_iter = ResultIter(fake_cursor)
+ # Get the only element.
+ first = next(result_iter)
+ assert first == OrderedDict([("a", 1), ("b", 2)])
+ # Next call should raise StopIteration and close the cursor.
+ with pytest.raises(StopIteration):
+ next(result_iter)
+ # Ensure cursor is closed.
+ def test_iter_result_proxy_empty(self):
+ """Test iter_result_proxy returns an empty list when given an empty cursor."""
+ fake_cursor = FakeCursor([])
+ result = list(iter_result_proxy(fake_cursor))
+ assert result == []
+
+ def test_iter_result_proxy_negative_step(self):
+ """Test iter_result_proxy with a negative step value produces expected slice behavior."""
+ # Create 3 dummy rows. With step = -1, fetchmany will return rows[:-1] on the first call.
+ rows = [DummyRow(a=i, b=i+10) for i in range(3)]
+ fake_cursor = FakeCursor(rows)
+ result = list(iter_result_proxy(fake_cursor, step=-1))
+ # When step is negative, list slicing returns all but the last element.
+ expected = rows[:-1]
+ assert result == expected
+
+ def test_normalize_column_name_trimmed(self):
+ """Test normalize_column_name trims whitespace from a valid column name."""
+ input_name = " trimmed_column "
+ result = normalize_column_name(input_name)
+ # The result should be trimmed of whitespace and still be valid.
+ assert result == "trimmed_column"
+
+ def test_normalize_table_name_trimmed(self):
+ """Test normalize_table_name trims whitespace from a valid table name."""
+ input_name = " trimmed_table "
+ result = normalize_table_name(input_name)
+ assert result == "trimmed_table"
+
+ def test_result_iter_iter_method(self):
+ """Test that the __iter__ method of ResultIter returns self."""
+ rows = [DummyRow(a=99, b=100)]
+ fake_cursor = FakeCursor(rows)
+ result_iter = ResultIter(fake_cursor)
+ # Calling iter() on result_iter should return itself.
+ # Verify that __iter__ returns self without closing the cursor
+ it = iter(result_iter)
+ assert it is result_iter
+ # Before iteration the cursor should still be open
+ assert not fake_cursor.closed
+ # Fully consume the iterator so that the cursor gets closed
+ list(result_iter)
+ # After iteration, the cursor should be closed
+ assert fake_cursor.closed
+# End of TestUtil class
\ No newline at end of file