diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 0000000..419e243
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: pytest
+# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
\ No newline at end of file
diff --git a/test/test_database.py b/test/test_database.py
new file mode 100644
index 0000000..5f39c5e
--- /dev/null
+++ b/test/test_database.py
@@ -0,0 +1,254 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+from sqlalchemy.sql import text
+
+from dataset import connect
+
+from .conftest import TEST_DATA
+
+
+def test_valid_database_url(db):
+ assert db.url, os.environ["DATABASE_URL"]
+
+
+def test_database_url_query_string(db):
+ db = connect("sqlite:///:memory:/?cached_statements=1")
+ assert "cached_statements" in db.url, db.url
+
+
+def test_tables(db, table):
+ assert db.tables == ["weather"], db.tables
+
+
+def test_contains(db, table):
+ assert "weather" in db, db.tables
+
+
+def test_create_table(db):
+ table = db["foo"]
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert "id" in table.table.c, table.table.c
+
+
+def test_create_table_no_ids(db):
+ if db.is_mysql or db.is_sqlite:
+ return
+ table = db.create_table("foo_no_id", primary_id=False)
+ assert table.table.name == "foo_no_id"
+ assert len(table.table.columns) == 0, table.table.columns
+
+
+def test_create_table_custom_id1(db):
+ pid = "string_id"
+ table = db.create_table("foo2", pid, db.types.string(255))
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+ table.insert({pid: "foobar"})
+ assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id2(db):
+ pid = "string_id"
+ table = db.create_table("foo3", pid, db.types.string(50))
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({pid: "foobar"})
+ assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id3(db):
+ pid = "int_id"
+ table = db.create_table("foo4", primary_id=pid)
+ assert db.has_table(table.table.name)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({pid: 123})
+ table.insert({pid: 124})
+ assert table.find_one(int_id=123)[pid] == 123
+ assert table.find_one(int_id=124)[pid] == 124
+ with pytest.raises(IntegrityError):
+ table.insert({pid: 123})
+ db.rollback()
+
+
+def test_create_table_shorthand1(db):
+ pid = "int_id"
+ table = db.get_table("foo5", pid)
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({"int_id": 123})
+ table.insert({"int_id": 124})
+ assert table.find_one(int_id=123)["int_id"] == 123
+ assert table.find_one(int_id=124)["int_id"] == 124
+ with pytest.raises(IntegrityError):
+ table.insert({"int_id": 123})
+
+
+def test_create_table_shorthand2(db):
+ pid = "string_id"
+ table = db.get_table("foo6", primary_id=pid, primary_type=db.types.string(255))
+ assert len(table.table.columns) == 1, table.table.columns
+ assert pid in table.table.c, table.table.c
+
+ table.insert({"string_id": "foobar"})
+ assert table.find_one(string_id="foobar")["string_id"] == "foobar"
+
+
+def test_with(db, table):
+ init_length = len(table)
+ with pytest.raises(ValueError):
+ with db:
+ table.insert(
+ {
+ "date": datetime(2011, 1, 1),
+ "temperature": 1,
+ "place": "tmp_place",
+ }
+ )
+ raise ValueError()
+ db.rollback()
+ assert len(table) == init_length
+
+
+def test_invalid_values(db, table):
+ if db.is_mysql:
+ # WARNING: mysql seems to be doing some weird type casting
+ # upon insert. The mysql-python driver is not affected but
+ # it isn't compatible with Python 3
+ # Conclusion: use postgresql.
+ return
+ with pytest.raises(SQLAlchemyError):
+ table.insert({"date": True, "temperature": "wrong_value", "place": "tmp_place"})
+
+
+def test_load_table(db, table):
+ tbl = db.load_table("weather")
+ assert tbl.table.name == table.table.name
+
+
+def test_query(db, table):
+ r = db.query("SELECT COUNT(*) AS num FROM weather").next()
+ assert r["num"] == len(TEST_DATA), r
+
+
+def test_table_cache_updates(db):
+ tbl1 = db.get_table("people")
+ data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
+ tbl1.insert(data)
+ data["id"] = 1
+ tbl2 = db.get_table("people")
+ assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
+
+def test_repr(db):
+ """Test __repr__ returns the correct database URL in its string representation."""
+ rep = repr(db)
+ assert db.url in rep
+
+def test_metadata(db):
+ """Test that the metadata property returns a proper MetaData instance with the correct schema."""
+ meta = db.metadata
+ from sqlalchemy.schema import MetaData
+ assert isinstance(meta, MetaData)
+ assert meta.schema == db.schema
+
+def test_in_transaction(db, table):
+ """Test that the in_transaction property accurately reflects the transaction state."""
+ # Upon accessing db.conn, a transaction is begun automatically
+ if db.is_sqlite:
+ # SQLite may not begin a transaction automatically so we start one explicitly
+ db.begin()
+ assert db.in_transaction is True
+ else:
+ assert db.in_transaction is True
+ db.commit()
+ assert db.in_transaction is False
+ with db:
+ assert db.in_transaction is True
+ assert db.in_transaction is False
+
+def test_close_connection(db):
+ # Close the database and override close() so that fixture teardown does not call close() again
+ db.close()
+ db.close = lambda: None
+ db.close()
+ import pytest
+ with pytest.raises(Exception):
+ _ = db.conn
+
+def test_query_step(db, table):
+ """Test that query() with _step parameter of 0 returns all rows (non-stepped result)."""
+ table.insert(dict(a=1))
+ table.insert(dict(a=2))
+ rows = list(db.query("SELECT a FROM %s" % table.table.name, _step=0))
+ assert len(rows) >= 2
+
+def test_contains_invalid(db):
+ """Test that __contains__ returns False for a non-existing table."""
+ assert "non_existing_table" not in db
+
+def test_on_connect_statements():
+ """Test that a database created with custom on_connect_statements is properly configured."""
+ custom_sql = "PRAGMA synchronous=OFF"
+ from dataset import connect
+ db2 = connect("sqlite:///:memory:/", on_connect_statements=[custom_sql])
+ assert db2.is_sqlite
+ db2.close()
+
+def test_transaction_rollback(db, table):
+ """Test that rollback properly undoes inserted records in an uncommitted transaction."""
+ initial_count = len(list(table.all()))
+ table.insert(dict(a=999))
+ db.rollback()
+ new_count = len(list(table.all()))
+ assert new_count == initial_count
+def test_op_property(db):
+ """Test that the op property returns an instance of alembic.operations.Operations."""
+ from alembic.operations import Operations
+ op = db.op
+ assert isinstance(op, Operations)
+ # Rollback any active transaction to clean up
+ db.rollback()
+
+def test_views(db):
+ """Test that the views property correctly returns a newly created view."""
+ import pytest
+ # Only run this test for sqlite or postgres as creating views might differ between dialects
+ if not (db.is_sqlite or db.is_postgres):
+ pytest.skip("Views test only valid for SQLite and PostgreSQL")
+ # Create a view using a raw SQL statement
+ db.conn.execute(text("CREATE VIEW test_view AS SELECT 1 AS col"))
+ db.commit()
+ assert "test_view" in db.views, "The created view 'test_view' was not found in db.views"
+ # Clean up: drop the view (use a try/except in case dropping fails)
+ try:
+ db.conn.execute(text("DROP VIEW test_view"))
+ except Exception:
+ pass
+ db.commit()
+
+def test_ipython_completions(db):
+ """Test that _ipython_key_completions_ returns a list equal to the tables property."""
+ completions = db._ipython_key_completions_()
+ assert isinstance(completions, list)
+ assert set(completions) == set(db.tables)
+
+def test_commit_empty_transaction(db):
+ """Test that commit works properly for an empty transaction."""
+ db.begin()
+ db.commit()
+ assert not db.in_transaction, "After commit, there should be no active transaction."
+
+def test_connection_reuse(db):
+ """Test that repeated calls to the db.conn property return the same connection instance."""
+ conn1 = db.conn
+ conn2 = db.conn
+ assert conn1 is conn2, "db.conn did not return the same connection instance on repeated calls."
\ No newline at end of file
diff --git a/test/test_table.py b/test/test_table.py
new file mode 100644
index 0000000..4d91101
--- /dev/null
+++ b/test/test_table.py
@@ -0,0 +1,576 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
+
+from dataset import chunked
+
+from .conftest import TEST_DATA, TEST_CITY_1
+
+
+def test_insert(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ last_id = table.insert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_insert_ignore(table):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_ignore_all_key(table):
+ for i in range(0, 4):
+ table.insert_ignore(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_insert_json(table):
+ last_id = table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "place": "Berlin",
+ "info": {
+ "currency": "EUR",
+ "language": "German",
+ "population": 3292365,
+ },
+ }
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ assert table.find_one(id=last_id)["place"] == "Berlin"
+
+
+def test_upsert(table):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_single_column(db):
+ table = db["banana_single_col"]
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+ table.upsert({"color": "Yellow"}, ["color"])
+ assert len(table) == 1, len(table)
+
+
+def test_upsert_all_key(table):
+ assert len(table) == len(TEST_DATA), len(table)
+ for i in range(0, 2):
+ table.upsert(
+ {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
+ ["date", "temperature", "place"],
+ )
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+
+
+def test_upsert_id(db):
+ table = db["banana_with_id"]
+ data = dict(id=10, title="I am a banana!")
+ table.upsert(data, ["id"])
+ assert len(table) == 1, len(table)
+
+
+def test_update_while_iter(table):
+ for row in table:
+ row["foo"] = "bar"
+ table.update(row, ["place", "date"])
+ assert len(table) == len(TEST_DATA), len(table)
+
+
+def test_weird_column_names(table):
+ with pytest.raises(ValueError):
+ table.insert(
+ {
+ "date": datetime(2011, 1, 2),
+ "temperature": -10,
+ "foo.bar": "Berlin",
+ "qux.bar": "Huhu",
+ }
+ )
+
+
+def test_cased_column_names(db):
+ tbl = db["cased_column_names"]
+ tbl.insert({"place": "Berlin"})
+ tbl.insert({"Place": "Berlin"})
+ tbl.insert({"PLACE ": "Berlin"})
+ assert len(tbl.columns) == 2, tbl.columns
+ assert len(list(tbl.find(Place="Berlin"))) == 3
+ assert len(list(tbl.find(place="Berlin"))) == 3
+ assert len(list(tbl.find(PLACE="Berlin"))) == 3
+
+
+def test_invalid_column_names(db):
+ tbl = db["weather"]
+ with pytest.raises(ValueError):
+ tbl.insert({None: "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"": "banana"})
+
+ with pytest.raises(ValueError):
+ tbl.insert({"-": "banana"})
+
+
+def test_delete(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ original_count = len(table)
+ assert len(table) == len(TEST_DATA) + 1, len(table)
+ # Test bad use of API
+ with pytest.raises(ArgumentError):
+ table.delete({"place": "Berlin"})
+ assert len(table) == original_count, len(table)
+
+ assert table.delete(place="Berlin") is True, "should return 1"
+ assert len(table) == len(TEST_DATA), len(table)
+ assert table.delete() is True, "should return non zero"
+ assert len(table) == 0, len(table)
+
+
+def test_repr(table):
+ assert (
+        repr(table) == "<Table(weather)>"
+    ), "the representation should be <Table(weather)>"
+
+
+def test_delete_nonexist_entry(table):
+ assert (
+ table.delete(place="Berlin") is False
+ ), "entry not exist, should fail to delete"
+
+
+def test_find_one(table):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"})
+ d = table.find_one(place="Berlin")
+ assert d["temperature"] == -10, d
+ d = table.find_one(place="Atlantis")
+ assert d is None, d
+
+
+def test_count(table):
+ assert len(table) == 6, len(table)
+ length = table.count(place=TEST_CITY_1)
+ assert length == 3, length
+
+
+def test_find(table):
+ ds = list(table.find(place=TEST_CITY_1))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=1, _step=2))
+ assert len(ds) == 1, ds
+ ds = list(table.find(_step=2))
+ assert len(ds) == len(TEST_DATA), ds
+ ds = list(table.find(order_by=["temperature"]))
+ assert ds[0]["temperature"] == -1, ds
+ ds = list(table.find(order_by=["-temperature"]))
+ assert ds[0]["temperature"] == 8, ds
+ ds = list(table.find(table.table.columns.temperature > 4))
+ assert len(ds) == 3, ds
+
+
+def test_find_dsl(table):
+ ds = list(table.find(place={"like": "%lw%"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={">": 5}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={">=": 5}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(temperature={"<": 0}))
+ assert len(ds) == 1, ds
+ ds = list(table.find(temperature={"<=": 0}))
+ assert len(ds) == 2, ds
+ ds = list(table.find(temperature={"!=": -1}))
+ assert len(ds) == 5, ds
+ ds = list(table.find(temperature={"between": [5, 8]}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"=": "G€lway"}))
+ assert len(ds) == 3, ds
+ ds = list(table.find(place={"ilike": "%LwAy"}))
+ assert len(ds) == 3, ds
+
+
+def test_offset(table):
+ ds = list(table.find(place=TEST_CITY_1, _offset=1))
+ assert len(ds) == 2, ds
+ ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2))
+ assert len(ds) == 1, ds
+
+
+def test_streamed_update(table):
+ ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1))
+ assert len(ds) == 3, len(ds)
+ for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1):
+ row["temperature"] = -1
+ table.update(row, ["id"])
+
+
+def test_distinct(table):
+ x = list(table.distinct("place"))
+ assert len(x) == 2, x
+ x = list(table.distinct("place", "date"))
+ assert len(x) == 6, x
+ x = list(
+ table.distinct(
+ "place",
+ "date",
+ table.table.columns.date >= datetime(2011, 1, 2, 0, 0),
+ )
+ )
+ assert len(x) == 4, x
+
+ x = list(table.distinct("temperature", place="B€rkeley"))
+ assert len(x) == 3, x
+ x = list(table.distinct("temperature", place=["B€rkeley", "G€lway"]))
+ assert len(x) == 6, x
+
+
+def test_insert_many(table):
+ data = TEST_DATA * 100
+ table.insert_many(data, chunk_size=13)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert(table):
+ data = TEST_DATA * 100
+ with chunked.ChunkedInsert(table) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(table) == len(data) + 6, (len(table), len(data))
+
+
+def test_chunked_insert_callback(table):
+ data = TEST_DATA * 100
+ N = 0
+
+ def callback(queue):
+ nonlocal N
+ N += len(queue)
+
+ with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl:
+ for item in data:
+ chunk_tbl.insert(item)
+ assert len(data) == N
+ assert len(table) == len(data) + 6
+
+
+def test_update_many(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
+ tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
+
+
+def test_chunked_update(db):
+ tbl = db["update_many_test"]
+ tbl.insert_many(
+ [
+ dict(temp=10, location="asdf"),
+ dict(temp=20, location="qwer"),
+ dict(temp=30, location="asdf"),
+ ]
+ )
+ db.commit()
+
+ chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"])
+ chunked_tbl.update(dict(id=1, temp=50))
+ chunked_tbl.update(dict(id=2, location="asdf"))
+ chunked_tbl.update(dict(id=3, temp=50))
+ chunked_tbl.flush()
+ db.commit()
+
+ # Ensure data has been updated.
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
+ assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
+
+
+def test_upsert_many(db):
+ # Also tests updating on records with different attributes
+ tbl = db["upsert_many_test"]
+
+ W = 100
+ tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
+ assert tbl.find_one(id=1)["age"] == 10
+
+ tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
+ assert tbl.find_one(id=2)["weight"] == W / 2
+
+
+def test_drop_operations(table):
+ assert table._table is not None, "table shouldn't be dropped yet"
+ table.drop()
+ assert table._table is None, "table should be dropped now"
+ assert list(table.all()) == [], table.all()
+ assert table.count() == 0, table.count()
+
+
+def test_table_drop(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+
+
+def test_table_drop_then_create(db, table):
+ assert "weather" in db
+ db["weather"].drop()
+ assert "weather" not in db
+ db["weather"].insert({"foo": "bar"})
+
+
+def test_columns(table):
+ cols = table.columns
+ assert len(list(cols)) == 4, "column count mismatch"
+ assert "date" in cols and "temperature" in cols and "place" in cols
+
+
+def test_drop_column(table):
+ try:
+ table.drop_column("date")
+ assert "date" not in table.columns
+ except RuntimeError:
+ pass
+
+
+def test_iter(table):
+ c = 0
+ for row in table:
+ c += 1
+ assert c == len(table)
+
+
+def test_update(table):
+ date = datetime(2011, 1, 2)
+ res = table.update(
+ {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"]
+ )
+ assert res, "update should return True"
+ m = table.find_one(place=TEST_CITY_1, date=date)
+ assert m["temperature"] == -10, (
+ "new temp. should be -10 but is %d" % m["temperature"]
+ )
+
+
+def test_create_column(db, table):
+ flt = db.types.float
+ table.create_column("foo", flt)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type
+ assert "foo" in table.columns, table.columns
+
+
+def test_ensure_column(db, table):
+ flt = db.types.float
+ table.create_column_by_example("foo", 0.1)
+ assert "foo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["foo"].type, flt), table.table.c["bar"].type
+ table.create_column_by_example("bar", 1)
+ assert "bar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type
+ table.create_column_by_example("pippo", "test")
+ assert "pippo" in table.table.c, table.table.c
+ assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type
+ table.create_column_by_example("bigbar", 11111111111)
+ assert "bigbar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[
+ "bigbar"
+ ].type
+ table.create_column_by_example("littlebar", -11111111111)
+ assert "littlebar" in table.table.c, table.table.c
+ assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[
+ "littlebar"
+ ].type
+
+
+def test_key_order(db, table):
+ res = db.query("SELECT temperature, place FROM weather LIMIT 1")
+ keys = list(res.next().keys())
+ assert keys[0] == "temperature"
+ assert keys[1] == "place"
+
+
+def test_empty_query(table):
+ empty = list(table.find(place="not in data"))
+ assert len(empty) == 0, empty
+
+def test_exists_property(db):
+ """Test the exists property of a table before and after creation and drop."""
+ # Create a new temporary table instance with a unique name.
+ temp = db["temp_exists"]
+ # Before any operations, the table should not exist.
+ assert not temp.exists, "Table should not exist yet"
+ # Insert a row to trigger auto table creation (ensure=True forces schema creation)
+ temp.insert({"dummy": "data"}, ensure=True)
+ assert temp.exists, "Table should exist after insertion"
+ # Now drop the table and check exists property again.
+ temp.drop()
+ assert not temp.exists, "Table should not exist after drop"
+
+def test_insert_no_ensure(table, db):
+ """Test that inserting a row with an unknown column using ensure=False omits that column."""
+ # Create a dedicated table to avoid schema conflict with the main fixture.
+ test_tbl = db["no_ensure_test"]
+ # Insert a row with an extra column; since ensure is False it should not be created.
+ inserted = test_tbl.insert({"date": datetime(2011, 1, 2), "temperature": 25, "place": "Testville", "extra": "value"}, ensure=False)
+ new_row = test_tbl.find_one(id=inserted)
+ # The extra column should not appear in the stored row.
+ assert "extra" not in new_row, "Extra column should not be created when ensure is False"
+
+def test_create_index_non_existing(table):
+ """Test creating an index with a non-existent column does nothing."""
+ # Create a column that is known to exist.
+ table.create_column("existing", VARCHAR(50))
+ # Attempt to create an index on one existing and one non-existing column.
+ table.create_index(["existing", "non_existing"], name="idx_test")
+ # Since "non_existing" does not exist, the index should not be created.
+ assert not table.has_index(["existing", "non_existing"]), "Index should not be created with non-existent column"
+
+def test_update_no_op(table):
+ """Test updating a row with only key fields returns the count without making alterations."""
+ # Insert a test row.
+ row = {"date": datetime(2011, 1, 2), "temperature": 15, "place": "Nowhere"}
+ inserted = table.insert(row, ensure=True)
+ # Invoke update with no additional update fields (only key values provided).
+ count = table.update({
+ "date": row["date"],
+ "temperature": row["temperature"],
+ "place": row["place"]
+ }, ["date", "temperature", "place"])
+ # Should return the count (the number of matched rows); we’re not altering any value.
+ assert count >= 1, "Update with no changes should return the number of matching rows"
+
+def test_chunked_insert_edge_case(table):
+ """Test that a chunked insert with empty data does not modify the table."""
+ initial_len = len(table)
+ with chunked.ChunkedInsert(table) as ci:
+ # Do not insert any data.
+ pass
+ assert len(table) == initial_len, "Empty chunked insert should not change table count"
+
+def test_stream_results(table):
+ """Test that streamed query returns the expected number of rows."""
+ # Insert a few rows with a distinct place.
+ for i in range(3):
+ table.insert({"date": datetime(2011, 1, 2), "temperature": 20+i, "place": "Streamville"}, ensure=True)
+ rows = list(table.find(place="Streamville", _streamed=True, _step=1))
+ assert len(rows) == 3, "Streamed query should return exactly three rows"
+
+# End of inserted tests.
+def test_unknown_operator_filter(table):
+ """Test that using an unknown operator yields no results."""
+ # Insert a row with a distinctive place to later filter on.
+ table.insert({"date": datetime(2021, 1, 1), "temperature": 25, "place": "UnknownOp"}, ensure=True)
+ # Use an unknown operator 'foobar' – _generate_clause will return false() so no row should match.
+ results = list(table.find(place={"foobar": "UnknownOp"}))
+ assert len(results) == 0, "Expected no results with an unknown operator filter"
+
+def test_order_by_invalid_column(table):
+ """Test that ordering by a non-existing column does not break the query."""
+ # Order by a non-existent column should simply be ignored.
+ results = list(table.find(order_by=["nonexistent"]))
+ # Expect all rows to be returned.
+ assert len(results) == len(table), "Expected all rows when ordering by an invalid column"
+
+def test_create_duplicate_column(db, table):
+ """Test that creating a duplicate column does not alter the table schema."""
+ initial_columns = set(table.columns)
+ # Create a new column 'duplicate_test'
+ table.create_column("duplicate_test", INTEGER)
+ # Calling create_column again should not add an extra column.
+ table.create_column("duplicate_test", INTEGER)
+ new_columns = set(table.columns)
+ # The new set should have exactly one more column than the initial set.
+ assert "duplicate_test" in new_columns, "Column 'duplicate_test' should exist in the schema"
+ assert len(new_columns) == len(initial_columns) + 1, "Duplicate column creation should not add extra columns"
+
+def test_update_keys_removal(table):
+ """Test that keys used for matching are removed from the update payload."""
+ # Insert a row with an extra field.
+ row_data = {"date": datetime(2022, 2, 2), "temperature": 15, "place": "KeyTest", "extra": "initial"}
+ rid = table.insert(row_data.copy(), ensure=True)
+ # Update using keys (date and place) while also changing the 'extra' value.
+ update_data = {"date": row_data["date"], "place": row_data["place"], "extra": "updated"}
+ count = table.update(update_data, ["date", "place"])
+ assert count >= 1, "Update should affect at least one row"
+ updated_row = table.find_one(id=rid)
+ assert updated_row.get("extra") == "updated", "Update did not modify 'extra' field correctly"
+def test_update_many_empty(db):
+ """
+ Test that update_many with an empty list of rows does not alter the table.
+ """
+ tbl = db["empty_update_test"]
+ # Insert a couple of rows into the table.
+ tbl.insert_many([{"temp": 10}, {"temp": 20}])
+ initial_count = tbl.count()
+ # Call update_many with an empty list.
+ tbl.update_many([], "id")
+ assert tbl.count() == initial_count, "update_many with empty list should not change table rows"
+
+def test_upsert_many_empty(db):
+ """
+ Test that upsert_many with an empty list does not alter the table.
+ """
+ tbl = db["empty_upsert_test"]
+ tbl.insert({"name": "Alice"}, ensure=True)
+ initial_count = tbl.count()
+ tbl.upsert_many([], "id")
+ assert tbl.count() == initial_count, "upsert_many with empty list should not change table rows"
+
+def test_order_by_mixed(table):
+ """
+ Test ordering with a mix of valid and invalid columns.
+ Invalid ordering columns should be ignored.
+ """
+ # Add test rows with specific temperatures.
+ table.insert({"date": datetime(2022, 1, 1), "temperature": 15, "place": "X"})
+ table.insert({"date": datetime(2022, 1, 2), "temperature": 10, "place": "Y"})
+ table.insert({"date": datetime(2022, 1, 3), "temperature": 20, "place": "Z"})
+ # Order by valid temperature ascending; an invalid column "nonexistent" should be ignored.
+ results = list(table.find(order_by=["temperature", "nonexistent"]))
+ temps = [row["temperature"] for row in results]
+ assert temps == sorted(temps), "Rows should be ordered by temperature ascending"
+
+ # Order by descending temperature using valid flag; again invalid should be ignored.
+ results_desc = list(table.find(order_by=["-temperature", "nonexistent"]))
+ temps_desc = [row["temperature"] for row in results_desc]
+ assert temps_desc == sorted(temps_desc, reverse=True), "Rows should be ordered by temperature descending"
+
+def test_insert_boolean_and_null(table):
+ """
+ Test inserting rows with boolean and None values.
+ """
+ # Insert a row with a None temperature and a boolean value in 'active'.
+ inserted_id = table.insert({"date": datetime(2020, 5, 5), "temperature": None, "place": "TestVille", "active": True}, ensure=True)
+ row = table.find_one(id=inserted_id)
+ assert row["temperature"] is None, "Temperature should be None"
+ assert row["active"] is True, "Active should be True"
\ No newline at end of file
diff --git a/tests/test_setup.py b/tests/test_setup.py
new file mode 100644
index 0000000..29068e6
--- /dev/null
+++ b/tests/test_setup.py
@@ -0,0 +1,181 @@
+import importlib
+import pytest
+import setuptools
+
class TestSetup:
    """Test suite for verifying setup.py configuration.

    Each test re-executes the top-level code of setup.py with
    setuptools.setup monkeypatched to record its keyword arguments
    instead of performing a real build/install.
    """

    @staticmethod
    def _run_setup(monkeypatch, tmp_path, readme_text=None):
        """Re-execute setup.py inside *tmp_path* and capture the setup() kwargs.

        When *readme_text* is not None, a README.md with that content is
        created first (setup.py reads it for long_description); when it is
        None the README is deliberately absent so setup.py should fail.
        Returns the dict of keyword arguments passed to setuptools.setup.
        """
        import sys

        if readme_text is not None:
            (tmp_path / "README.md").write_text(readme_text)
        # Change working directory so setup.py resolves README.md in tmp_path.
        monkeypatch.chdir(tmp_path)

        # Record the arguments passed to setuptools.setup instead of building.
        captured = {}
        monkeypatch.setattr(setuptools, "setup", lambda **kwargs: captured.update(kwargs))

        # Evict any cached module so the top-level setup() call re-runs on import.
        sys.modules.pop("setup", None)
        import setup
        importlib.reload(setup)
        return captured

    def test_setup_configuration(self, monkeypatch, tmp_path):
        """Test that setup() is called with correct metadata when README.md exists."""
        test_readme = "This is a dummy README for testing."
        captured = self._run_setup(monkeypatch, tmp_path, test_readme)
        assert captured.get("name") == "dataset"
        assert captured.get("version") == "1.6.0"
        assert captured.get("long_description") == test_readme
        assert "sqlalchemy >= 2.0.15, < 3.0.0" in captured.get("install_requires", [])

    def test_missing_readme(self, monkeypatch, tmp_path):
        """Test that missing README.md file raises FileNotFoundError during module import."""
        with pytest.raises(FileNotFoundError):
            self._run_setup(monkeypatch, tmp_path, readme_text=None)

    def test_setup_complete_metadata(self, monkeypatch, tmp_path):
        """Test that setup() is called with complete metadata from setup.py."""
        test_readme = "Complete metadata test."
        captured = self._run_setup(monkeypatch, tmp_path, test_readme)

        # Verify that all of the expected metadata keys and values are present.
        assert captured.get("name") == "dataset"
        assert captured.get("version") == "1.6.0"
        assert captured.get("description") == "Toolkit for Python-based database access."
        assert captured.get("long_description") == test_readme
        assert isinstance(captured.get("classifiers"), list) and len(captured.get("classifiers")) > 0
        assert captured.get("keywords") == "sql sqlalchemy etl loading utility"
        assert captured.get("author") == "Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
        assert captured.get("author_email") == "friedrich.lindenberg@gmail.com"
        assert captured.get("url") == "http://github.com/pudo/dataset"
        assert captured.get("license") == "MIT"
        assert isinstance(captured.get("install_requires"), list)
        assert "sqlalchemy >= 2.0.15, < 3.0.0" in captured.get("install_requires")
        assert captured.get("extras_require") == {
            "dev": [
                "pip",
                "pytest",
                "wheel",
                "flake8",
                "coverage",
                "psycopg2-binary",
                "PyMySQL",
                "cryptography",
            ]
        }
        assert captured.get("tests_require") == ["pytest"]
        assert captured.get("test_suite") == "test"
        # packages is generated via find_packages; we check that a value is present.
        assert captured.get("packages") is not None
        assert captured.get("namespace_packages") == []
        assert captured.get("include_package_data") is False
        assert captured.get("zip_safe") is False

    def test_find_packages_called(self, monkeypatch, tmp_path):
        """Test that the packages parameter is computed using monkeypatched find_packages."""
        # Patch find_packages BEFORE setup.py runs so it picks up the stub.
        monkeypatch.setattr(setuptools, "find_packages", lambda exclude: ["dummy_pkg"])
        captured = self._run_setup(monkeypatch, tmp_path, "dummy")
        assert captured.get("packages") == ["dummy_pkg"]

    def test_install_requires_length(self, monkeypatch, tmp_path):
        """Test that install_requires contains exactly three dependencies with correct content."""
        captured = self._run_setup(monkeypatch, tmp_path, "dummy")
        install_requires = captured.get("install_requires")
        assert isinstance(install_requires, list)
        assert len(install_requires) == 3
        assert "sqlalchemy >= 2.0.15, < 3.0.0" in install_requires
        assert "alembic >= 1.11.1" in install_requires
        assert "banal >= 1.0.1" in install_requires

    def test_long_description_content_type(self, monkeypatch, tmp_path):
        """Test that long_description_content_type is set to 'text/markdown'."""
        captured = self._run_setup(
            monkeypatch, tmp_path, "Dummy content for long_description_content_type test"
        )
        assert captured.get("long_description_content_type") == "text/markdown"

    def test_entry_points(self, monkeypatch, tmp_path):
        """Test that the entry_points field is an empty dictionary."""
        captured = self._run_setup(monkeypatch, tmp_path, "Dummy content for entry_points test")
        assert captured.get("entry_points") == {}
\ No newline at end of file
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 0000000..41a2889
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,304 @@
+import pytest
+from collections import namedtuple, OrderedDict
+from urllib.parse import urlparse, parse_qs
+from hashlib import sha1
+from dataset.util import (
+ convert_row,
+ iter_result_proxy,
+ make_sqlite_url,
+ ResultIter,
+ normalize_column_name,
+ normalize_column_key,
+ normalize_table_name,
+ safe_url,
+ index_name,
+ pad_chunk_columns,
+ QUERY_STEP,
+ row_type
+)
+
class DummyResultProxy:
    """A dummy result proxy that serves pre-baked chunks of rows.

    Both fetchall() and fetchmany() hand out the next unconsumed chunk in
    order; once every chunk has been served, both return an empty list.
    """

    def __init__(self, chunks):
        self.chunks = chunks
        self.index = 0

    def _next_chunk(self):
        # Serve chunks sequentially; an exhausted proxy keeps returning [].
        if self.index >= len(self.chunks):
            return []
        chunk = self.chunks[self.index]
        self.index += 1
        return chunk

    def fetchall(self):
        return self._next_chunk()

    def fetchmany(self, size):
        # `size` is accepted for API compatibility but deliberately ignored.
        return self._next_chunk()
+
class DummyClosedCursor:
    """A dummy cursor standing in for an already-closed database resource."""

    def keys(self):
        # Imported lazily so merely defining this class needs no sqlalchemy.
        from sqlalchemy.exc import ResourceClosedError

        raise ResourceClosedError()

    def close(self):
        # Closing an already-closed cursor is a harmless no-op.
        return None
+
class TestUtil:
    """Unit tests for the helpers in dataset/util.py."""

    def test_convert_row_valid(self):
        """Test convert_row with a valid row object."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        row_obj = DummyRow(1, "x")
        result = convert_row(OrderedDict, row_obj)
        assert result == OrderedDict([("a", 1), ("b", "x")])

    def test_convert_row_none(self):
        """Test convert_row with None as row."""
        result = convert_row(OrderedDict, None)
        assert result is None

    def test_iter_result_proxy_fetchall(self):
        """Test iter_result_proxy using fetchall."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x"), DummyRow(2, "y")]
        rp = DummyResultProxy(chunks=[rows])
        iterator = iter_result_proxy(rp)
        collected = list(iterator)
        assert collected == rows

    def test_iter_result_proxy_fetchmany(self):
        """Test iter_result_proxy using fetchmany with step size."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows_part1 = [DummyRow(1, "x")]
        rows_part2 = [DummyRow(2, "y")]
        rp = DummyResultProxy(chunks=[rows_part1, rows_part2])
        iterator = iter_result_proxy(rp, step=1)
        collected = list(iterator)
        assert collected == rows_part1 + rows_part2

    def test_make_sqlite_url_default(self):
        """Test make_sqlite_url with default parameters."""
        url = make_sqlite_url("test.db")
        assert url == "sqlite:///test.db"

    def test_make_sqlite_url_with_params(self):
        """Test make_sqlite_url with various parameters."""
        url = make_sqlite_url("test.db", cache="shared", timeout=30, mode="ro", check_same_thread=False, immutable=True, nolock=True)
        # URL should start with 'sqlite:///file:test.db?' and include parameters
        assert url.startswith("sqlite:///file:test.db?")
        query = url.split("?", 1)[1]
        params = parse_qs(query)
        assert params.get("cache") == ["shared"]
        assert params.get("timeout") == ["30"]
        assert params.get("mode") == ["ro"]
        assert params.get("nolock") == ["1"]
        assert params.get("immutable") == ["1"]
        assert params.get("check_same_thread") == ["false"]
        assert params.get("uri") == ["true"]

    def test_result_iter_normal(self):
        """Test ResultIter normal behavior."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x"), DummyRow(2, "y")]

        # Create a dummy cursor for ResultIter
        class DummyCursorForResultIter:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                self.closed = True

        cursor = DummyCursorForResultIter(rows)
        iter_obj = ResultIter(cursor, row_type=OrderedDict)
        collected = list(iter_obj)
        expected = [OrderedDict(zip(r._fields, r)) for r in rows]
        assert collected == expected

    def test_result_iter_closed(self):
        """Test ResultIter with a closed cursor that raises ResourceClosedError."""
        cursor = DummyClosedCursor()
        iter_obj = ResultIter(cursor)
        # Since the cursor raises ResourceClosedError, keys should be empty and iteration should be empty
        assert iter_obj.keys == []
        with pytest.raises(StopIteration):
            next(iter_obj)

    def test_normalize_column_name_valid(self):
        """Test normalize_column_name with valid name."""
        valid_name = " column_name "
        normalized = normalize_column_name(valid_name)
        # Should be trimmed and valid (not containing dots or dashes and within 63 characters)
        assert normalized == "column_name"

    def test_normalize_column_name_invalid(self):
        """Test normalize_column_name with invalid names."""
        with pytest.raises(ValueError):
            normalize_column_name("column.name")
        with pytest.raises(ValueError):
            normalize_column_name("column-name")
        with pytest.raises(ValueError):
            normalize_column_name(" ")  # empty after stripping

    def test_normalize_column_key(self):
        """Test normalize_column_key transforms to uppercase and removes spaces."""
        assert normalize_column_key("column key") == "COLUMNKEY"
        assert normalize_column_key(" Column ") == "COLUMN"
        assert normalize_column_key(None) is None

    def test_normalize_table_name_valid(self):
        """Test normalize_table_name with a valid table name."""
        table_name = " MyTable "
        normalized = normalize_table_name(table_name)
        assert normalized == "MyTable"

    def test_normalize_table_name_invalid(self):
        """Test normalize_table_name with an invalid table name."""
        with pytest.raises(ValueError):
            normalize_table_name(" ")
        with pytest.raises(ValueError):
            normalize_table_name(123)

    def test_safe_url(self):
        """Test safe_url to ensure the password is hidden."""
        url_with_pwd = "postgresql://user:secret@localhost/db"
        safe = safe_url(url_with_pwd)
        assert "secret" not in safe
        assert "*****" in safe
        # URL without a password should remain unchanged
        url_no_pwd = "postgresql://user@localhost/db"
        assert safe_url(url_no_pwd) == url_no_pwd

    def test_index_name(self):
        """Test index_name generates a valid index name."""
        table = "mytable"
        columns = ["col1", "col2"]
        ix = index_name(table, columns)
        # Check that the index name starts with "ix_mytable_" and has proper length.
        assert ix.startswith("ix_mytable_")
        key_part = ix.split("_")[-1]
        assert len(key_part) == 16

    def test_pad_chunk_columns(self):
        """Test pad_chunk_columns ensures each record has the required columns."""
        chunk = [{"a": 1}, {"b": 2}, {"a": 3, "c": 4}]
        columns = ["a", "b", "c"]
        padded = pad_chunk_columns(chunk, columns)
        for record in padded:
            for col in columns:
                assert col in record
        # Check that missing columns are padded with None.
        assert padded[0]["b"] is None
        assert padded[0]["c"] is None
        assert padded[1]["a"] is None
        assert padded[1]["c"] is None

    def test_normalize_column_name_long(self):
        """Test that normalize_column_name trims long names to fit 63 chars and ensures utf-8 byte length is less than 64."""
        long_name = "a" * 100
        normalized = normalize_column_name(long_name)
        # For ASCII all characters are 1 byte, so should be exactly 63 characters
        assert len(normalized) == 63

    def test_normalize_column_name_non_string(self):
        """Test that normalize_column_name raises ValueError for non-string inputs."""
        with pytest.raises(ValueError):
            normalize_column_name(12345)

    def test_make_sqlite_url_timeout_only(self):
        """Test make_sqlite_url with only the timeout parameter provided."""
        url = make_sqlite_url("test_timeout.db", timeout=10)
        assert url.startswith("sqlite:///file:test_timeout.db?")
        query = url.split("?", 1)[1]
        params = parse_qs(query)
        assert params.get("timeout") == ["10"]
        assert params.get("uri") == ["true"]

    def test_pad_chunk_columns_empty_chunk(self):
        """Test that pad_chunk_columns returns an empty list when provided an empty chunk."""
        chunk = []
        columns = ["a", "b"]
        padded = pad_chunk_columns(chunk, columns)
        assert padded == []

    def test_index_name_empty_columns(self):
        """Test index_name with an empty list of columns produces a valid index name."""
        ix = index_name("table", [])
        expected_hash = sha1("".encode("utf-8")).hexdigest()[:16]
        expected = "ix_table_" + expected_hash
        assert ix == expected

    def test_result_iter_next_after_iteration(self):
        """Test that calling next() after exhausting the ResultIter raises StopIteration consistently."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x")]

        class DummyCursorForNextTest:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                self.closed = True

        cursor = DummyCursorForNextTest(rows)
        iter_obj = ResultIter(cursor, row_type=OrderedDict)
        # Exhaust the iterator
        list(iter_obj)
        with pytest.raises(StopIteration):
            next(iter_obj)

    def test_iter_result_proxy_empty_initially(self):
        """Test that iter_result_proxy returns an empty iterator when no rows are fetched."""
        # NOTE(review): this body previously lived, misplaced, at the end of
        # test_convert_row_invalid while this method was an empty `pass` stub.
        rp = DummyResultProxy(chunks=[])
        iterator = iter_result_proxy(rp)
        collected = list(iterator)
        assert collected == []

    def test_result_iter_calls_close(self):
        """Test that ResultIter.close() is called after iteration is exhausted."""
        DummyRow = namedtuple("DummyRow", ["a"])
        rows = [DummyRow(1)]
        closed_flag = {"closed": False}

        class DummyCursorForClose:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                closed_flag["closed"] = True

        cursor = DummyCursorForClose(rows)
        iter_obj = ResultIter(cursor, row_type=OrderedDict)
        # Force the iterator to exhaust
        list(iter_obj)
        assert closed_flag["closed"] is True

    def test_normalize_column_name_multibyte(self):
        """Test that normalize_column_name trims multibyte names to adhere to the utf-8 byte limit."""
        # Use a multibyte character (e.g., "é": 2 bytes in utf-8); 70 instances
        multi_byte_name = "é" * 70
        normalized = normalize_column_name(multi_byte_name)
        # Check that the utf-8 encoding of normalized is less than 64 bytes
        assert len(normalized.encode("utf-8")) < 64

    def test_convert_row_invalid(self):
        """Test convert_row with an object that does not have a _fields attribute."""
        class NoFields:
            pass

        with pytest.raises(AttributeError):
            convert_row(OrderedDict, NoFields())
+# End of tests for dataset/util.py
\ No newline at end of file