diff --git a/codebeaver.yml b/codebeaver.yml new file mode 100644 index 0000000..419e243 --- /dev/null +++ b/codebeaver.yml @@ -0,0 +1,2 @@ +from: pytest +# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/ \ No newline at end of file diff --git a/test/test_database.py b/test/test_database.py new file mode 100644 index 0000000..5f39c5e --- /dev/null +++ b/test/test_database.py @@ -0,0 +1,254 @@ +import os +import pytest +from datetime import datetime +from collections import OrderedDict +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.sql import text + +from dataset import connect + +from .conftest import TEST_DATA + + +def test_valid_database_url(db): + assert db.url, os.environ["DATABASE_URL"] + + +def test_database_url_query_string(db): + db = connect("sqlite:///:memory:/?cached_statements=1") + assert "cached_statements" in db.url, db.url + + +def test_tables(db, table): + assert db.tables == ["weather"], db.tables + + +def test_contains(db, table): + assert "weather" in db, db.tables + + +def test_create_table(db): + table = db["foo"] + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert "id" in table.table.c, table.table.c + + +def test_create_table_no_ids(db): + if db.is_mysql or db.is_sqlite: + return + table = db.create_table("foo_no_id", primary_id=False) + assert table.table.name == "foo_no_id" + assert len(table.table.columns) == 0, table.table.columns + + +def test_create_table_custom_id1(db): + pid = "string_id" + table = db.create_table("foo2", pid, db.types.string(255)) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + table.insert({pid: "foobar"}) + assert table.find_one(string_id="foobar")[pid] == "foobar" + + +def test_create_table_custom_id2(db): + pid = "string_id" + 
table = db.create_table("foo3", pid, db.types.string(50)) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({pid: "foobar"}) + assert table.find_one(string_id="foobar")[pid] == "foobar" + + +def test_create_table_custom_id3(db): + pid = "int_id" + table = db.create_table("foo4", primary_id=pid) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({pid: 123}) + table.insert({pid: 124}) + assert table.find_one(int_id=123)[pid] == 123 + assert table.find_one(int_id=124)[pid] == 124 + with pytest.raises(IntegrityError): + table.insert({pid: 123}) + db.rollback() + + +def test_create_table_shorthand1(db): + pid = "int_id" + table = db.get_table("foo5", pid) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({"int_id": 123}) + table.insert({"int_id": 124}) + assert table.find_one(int_id=123)["int_id"] == 123 + assert table.find_one(int_id=124)["int_id"] == 124 + with pytest.raises(IntegrityError): + table.insert({"int_id": 123}) + + +def test_create_table_shorthand2(db): + pid = "string_id" + table = db.get_table("foo6", primary_id=pid, primary_type=db.types.string(255)) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({"string_id": "foobar"}) + assert table.find_one(string_id="foobar")["string_id"] == "foobar" + + +def test_with(db, table): + init_length = len(table) + with pytest.raises(ValueError): + with db: + table.insert( + { + "date": datetime(2011, 1, 1), + "temperature": 1, + "place": "tmp_place", + } + ) + raise ValueError() + db.rollback() + assert len(table) == init_length + + +def test_invalid_values(db, table): + if db.is_mysql: + # WARNING: mysql seems to be doing some weird type casting + # upon 
insert. The mysql-python driver is not affected but + # it isn't compatible with Python 3 + # Conclusion: use postgresql. + return + with pytest.raises(SQLAlchemyError): + table.insert({"date": True, "temperature": "wrong_value", "place": "tmp_place"}) + + +def test_load_table(db, table): + tbl = db.load_table("weather") + assert tbl.table.name == table.table.name + + +def test_query(db, table): + r = db.query("SELECT COUNT(*) AS num FROM weather").next() + assert r["num"] == len(TEST_DATA), r + + +def test_table_cache_updates(db): + tbl1 = db.get_table("people") + data = OrderedDict([("first_name", "John"), ("last_name", "Smith")]) + tbl1.insert(data) + data["id"] = 1 + tbl2 = db.get_table("people") + assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data) + +def test_repr(db): + """Test __repr__ returns the correct database URL in its string representation.""" + rep = repr(db) + assert db.url in rep + +def test_metadata(db): + """Test that the metadata property returns a proper MetaData instance with the correct schema.""" + meta = db.metadata + from sqlalchemy.schema import MetaData + assert isinstance(meta, MetaData) + assert meta.schema == db.schema + +def test_in_transaction(db, table): + """Test that the in_transaction property accurately reflects the transaction state.""" + # Upon accessing db.conn, a transaction is begun automatically + if db.is_sqlite: + # SQLite may not begin a transaction automatically so we start one explicitly + db.begin() + assert db.in_transaction is True + else: + assert db.in_transaction is True + db.commit() + assert db.in_transaction is False + with db: + assert db.in_transaction is True + assert db.in_transaction is False + +def test_close_connection(db): + # Close the database and override close() so that fixture teardown does not call close() again + db.close() + db.close = lambda: None + db.close() + import pytest + with pytest.raises(Exception): + _ = db.conn + +def test_query_step(db, table): + """Test that 
query() with _step parameter of 0 returns all rows (non-stepped result).""" + table.insert(dict(a=1)) + table.insert(dict(a=2)) + rows = list(db.query("SELECT a FROM %s" % table.table.name, _step=0)) + assert len(rows) >= 2 + +def test_contains_invalid(db): + """Test that __contains__ returns False for a non-existing table.""" + assert "non_existing_table" not in db + +def test_on_connect_statements(): + """Test that a database created with custom on_connect_statements is properly configured.""" + custom_sql = "PRAGMA synchronous=OFF" + from dataset import connect + db2 = connect("sqlite:///:memory:/", on_connect_statements=[custom_sql]) + assert db2.is_sqlite + db2.close() + +def test_transaction_rollback(db, table): + """Test that rollback properly undoes inserted records in an uncommitted transaction.""" + initial_count = len(list(table.all())) + table.insert(dict(a=999)) + db.rollback() + new_count = len(list(table.all())) + assert new_count == initial_count +def test_op_property(db): + """Test that the op property returns an instance of alembic.operations.Operations.""" + from alembic.operations import Operations + op = db.op + assert isinstance(op, Operations) + # Rollback any active transaction to clean up + db.rollback() + +def test_views(db): + """Test that the views property correctly returns a newly created view.""" + import pytest + # Only run this test for sqlite or postgres as creating views might differ between dialects + if not (db.is_sqlite or db.is_postgres): + pytest.skip("Views test only valid for SQLite and PostgreSQL") + # Create a view using a raw SQL statement + db.conn.execute(text("CREATE VIEW test_view AS SELECT 1 AS col")) + db.commit() + assert "test_view" in db.views, "The created view 'test_view' was not found in db.views" + # Clean up: drop the view (use a try/except in case dropping fails) + try: + db.conn.execute(text("DROP VIEW test_view")) + except Exception: + pass + db.commit() + +def test_ipython_completions(db): + """Test 
that _ipython_key_completions_ returns a list equal to the tables property.""" + completions = db._ipython_key_completions_() + assert isinstance(completions, list) + assert set(completions) == set(db.tables) + +def test_commit_empty_transaction(db): + """Test that commit works properly for an empty transaction.""" + db.begin() + db.commit() + assert not db.in_transaction, "After commit, there should be no active transaction." + +def test_connection_reuse(db): + """Test that repeated calls to the db.conn property return the same connection instance.""" + conn1 = db.conn + conn2 = db.conn + assert conn1 is conn2, "db.conn did not return the same connection instance on repeated calls." \ No newline at end of file diff --git a/test/test_table.py b/test/test_table.py new file mode 100644 index 0000000..4d91101 --- /dev/null +++ b/test/test_table.py @@ -0,0 +1,576 @@ +import os +import pytest +from datetime import datetime +from collections import OrderedDict +from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT +from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError + +from dataset import chunked + +from .conftest import TEST_DATA, TEST_CITY_1 + + +def test_insert(table): + assert len(table) == len(TEST_DATA), len(table) + last_id = table.insert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + assert table.find_one(id=last_id)["place"] == "Berlin" + + +def test_insert_ignore(table): + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_insert_ignore_all_key(table): + for i in range(0, 4): + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": 
"Berlin"}, + ["date", "temperature", "place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_insert_json(table): + last_id = table.insert( + { + "date": datetime(2011, 1, 2), + "temperature": -10, + "place": "Berlin", + "info": { + "currency": "EUR", + "language": "German", + "population": 3292365, + }, + } + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + assert table.find_one(id=last_id)["place"] == "Berlin" + + +def test_upsert(table): + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_upsert_single_column(db): + table = db["banana_single_col"] + table.upsert({"color": "Yellow"}, ["color"]) + assert len(table) == 1, len(table) + table.upsert({"color": "Yellow"}, ["color"]) + assert len(table) == 1, len(table) + + +def test_upsert_all_key(table): + assert len(table) == len(TEST_DATA), len(table) + for i in range(0, 2): + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["date", "temperature", "place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_upsert_id(db): + table = db["banana_with_id"] + data = dict(id=10, title="I am a banana!") + table.upsert(data, ["id"]) + assert len(table) == 1, len(table) + + +def test_update_while_iter(table): + for row in table: + row["foo"] = "bar" + table.update(row, ["place", "date"]) + assert len(table) == len(TEST_DATA), len(table) + + +def test_weird_column_names(table): + with pytest.raises(ValueError): + table.insert( + { + "date": datetime(2011, 1, 2), + "temperature": -10, + "foo.bar": "Berlin", + "qux.bar": "Huhu", + } + ) + + +def test_cased_column_names(db): + tbl = db["cased_column_names"] + tbl.insert({"place": "Berlin"}) + 
tbl.insert({"Place": "Berlin"}) + tbl.insert({"PLACE ": "Berlin"}) + assert len(tbl.columns) == 2, tbl.columns + assert len(list(tbl.find(Place="Berlin"))) == 3 + assert len(list(tbl.find(place="Berlin"))) == 3 + assert len(list(tbl.find(PLACE="Berlin"))) == 3 + + +def test_invalid_column_names(db): + tbl = db["weather"] + with pytest.raises(ValueError): + tbl.insert({None: "banana"}) + + with pytest.raises(ValueError): + tbl.insert({"": "banana"}) + + with pytest.raises(ValueError): + tbl.insert({"-": "banana"}) + + +def test_delete(table): + table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}) + original_count = len(table) + assert len(table) == len(TEST_DATA) + 1, len(table) + # Test bad use of API + with pytest.raises(ArgumentError): + table.delete({"place": "Berlin"}) + assert len(table) == original_count, len(table) + + assert table.delete(place="Berlin") is True, "should return 1" + assert len(table) == len(TEST_DATA), len(table) + assert table.delete() is True, "should return non zero" + assert len(table) == 0, len(table) + + +def test_repr(table): + assert ( + repr(table) == "" + ), "the representation should be " + + +def test_delete_nonexist_entry(table): + assert ( + table.delete(place="Berlin") is False + ), "entry not exist, should fail to delete" + + +def test_find_one(table): + table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}) + d = table.find_one(place="Berlin") + assert d["temperature"] == -10, d + d = table.find_one(place="Atlantis") + assert d is None, d + + +def test_count(table): + assert len(table) == 6, len(table) + length = table.count(place=TEST_CITY_1) + assert length == 3, length + + +def test_find(table): + ds = list(table.find(place=TEST_CITY_1)) + assert len(ds) == 3, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2)) + assert len(ds) == 2, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1)) + assert len(ds) == 2, ds + ds = 
list(table.find(place=TEST_CITY_1, _limit=1, _step=2)) + assert len(ds) == 1, ds + ds = list(table.find(_step=2)) + assert len(ds) == len(TEST_DATA), ds + ds = list(table.find(order_by=["temperature"])) + assert ds[0]["temperature"] == -1, ds + ds = list(table.find(order_by=["-temperature"])) + assert ds[0]["temperature"] == 8, ds + ds = list(table.find(table.table.columns.temperature > 4)) + assert len(ds) == 3, ds + + +def test_find_dsl(table): + ds = list(table.find(place={"like": "%lw%"})) + assert len(ds) == 3, ds + ds = list(table.find(temperature={">": 5})) + assert len(ds) == 2, ds + ds = list(table.find(temperature={">=": 5})) + assert len(ds) == 3, ds + ds = list(table.find(temperature={"<": 0})) + assert len(ds) == 1, ds + ds = list(table.find(temperature={"<=": 0})) + assert len(ds) == 2, ds + ds = list(table.find(temperature={"!=": -1})) + assert len(ds) == 5, ds + ds = list(table.find(temperature={"between": [5, 8]})) + assert len(ds) == 3, ds + ds = list(table.find(place={"=": "G€lway"})) + assert len(ds) == 3, ds + ds = list(table.find(place={"ilike": "%LwAy"})) + assert len(ds) == 3, ds + + +def test_offset(table): + ds = list(table.find(place=TEST_CITY_1, _offset=1)) + assert len(ds) == 2, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2)) + assert len(ds) == 1, ds + + +def test_streamed_update(table): + ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1)) + assert len(ds) == 3, len(ds) + for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1): + row["temperature"] = -1 + table.update(row, ["id"]) + + +def test_distinct(table): + x = list(table.distinct("place")) + assert len(x) == 2, x + x = list(table.distinct("place", "date")) + assert len(x) == 6, x + x = list( + table.distinct( + "place", + "date", + table.table.columns.date >= datetime(2011, 1, 2, 0, 0), + ) + ) + assert len(x) == 4, x + + x = list(table.distinct("temperature", place="B€rkeley")) + assert len(x) == 3, x + x = 
list(table.distinct("temperature", place=["B€rkeley", "G€lway"])) + assert len(x) == 6, x + + +def test_insert_many(table): + data = TEST_DATA * 100 + table.insert_many(data, chunk_size=13) + assert len(table) == len(data) + 6, (len(table), len(data)) + + +def test_chunked_insert(table): + data = TEST_DATA * 100 + with chunked.ChunkedInsert(table) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(table) == len(data) + 6, (len(table), len(data)) + + +def test_chunked_insert_callback(table): + data = TEST_DATA * 100 + N = 0 + + def callback(queue): + nonlocal N + N += len(queue) + + with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(data) == N + assert len(table) == len(data) + 6 + + +def test_update_many(db): + tbl = db["update_many_test"] + tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)]) + tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id") + + # Ensure data has been updated. + assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] + + +def test_chunked_update(db): + tbl = db["update_many_test"] + tbl.insert_many( + [ + dict(temp=10, location="asdf"), + dict(temp=20, location="qwer"), + dict(temp=30, location="asdf"), + ] + ) + db.commit() + + chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"]) + chunked_tbl.update(dict(id=1, temp=50)) + chunked_tbl.update(dict(id=2, location="asdf")) + chunked_tbl.update(dict(id=3, temp=50)) + chunked_tbl.flush() + db.commit() + + # Ensure data has been updated. 
+ assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50 + assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf" + + +def test_upsert_many(db): + # Also tests updating on records with different attributes + tbl = db["upsert_many_test"] + + W = 100 + tbl.upsert_many([dict(age=10), dict(weight=W)], "id") + assert tbl.find_one(id=1)["age"] == 10 + + tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id") + assert tbl.find_one(id=2)["weight"] == W / 2 + + +def test_drop_operations(table): + assert table._table is not None, "table shouldn't be dropped yet" + table.drop() + assert table._table is None, "table should be dropped now" + assert list(table.all()) == [], table.all() + assert table.count() == 0, table.count() + + +def test_table_drop(db, table): + assert "weather" in db + db["weather"].drop() + assert "weather" not in db + + +def test_table_drop_then_create(db, table): + assert "weather" in db + db["weather"].drop() + assert "weather" not in db + db["weather"].insert({"foo": "bar"}) + + +def test_columns(table): + cols = table.columns + assert len(list(cols)) == 4, "column count mismatch" + assert "date" in cols and "temperature" in cols and "place" in cols + + +def test_drop_column(table): + try: + table.drop_column("date") + assert "date" not in table.columns + except RuntimeError: + pass + + +def test_iter(table): + c = 0 + for row in table: + c += 1 + assert c == len(table) + + +def test_update(table): + date = datetime(2011, 1, 2) + res = table.update( + {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"] + ) + assert res, "update should return True" + m = table.find_one(place=TEST_CITY_1, date=date) + assert m["temperature"] == -10, ( + "new temp. 
should be -10 but is %d" % m["temperature"] + ) + + +def test_create_column(db, table): + flt = db.types.float + table.create_column("foo", flt) + assert "foo" in table.table.c, table.table.c + assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type + assert "foo" in table.columns, table.columns + + +def test_ensure_column(db, table): + flt = db.types.float + table.create_column_by_example("foo", 0.1) + assert "foo" in table.table.c, table.table.c + assert isinstance(table.table.c["foo"].type, flt), table.table.c["bar"].type + table.create_column_by_example("bar", 1) + assert "bar" in table.table.c, table.table.c + assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type + table.create_column_by_example("pippo", "test") + assert "pippo" in table.table.c, table.table.c + assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type + table.create_column_by_example("bigbar", 11111111111) + assert "bigbar" in table.table.c, table.table.c + assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[ + "bigbar" + ].type + table.create_column_by_example("littlebar", -11111111111) + assert "littlebar" in table.table.c, table.table.c + assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[ + "littlebar" + ].type + + +def test_key_order(db, table): + res = db.query("SELECT temperature, place FROM weather LIMIT 1") + keys = list(res.next().keys()) + assert keys[0] == "temperature" + assert keys[1] == "place" + + +def test_empty_query(table): + empty = list(table.find(place="not in data")) + assert len(empty) == 0, empty + +def test_exists_property(db): + """Test the exists property of a table before and after creation and drop.""" + # Create a new temporary table instance with a unique name. + temp = db["temp_exists"] + # Before any operations, the table should not exist. 
+ assert not temp.exists, "Table should not exist yet" + # Insert a row to trigger auto table creation (ensure=True forces schema creation) + temp.insert({"dummy": "data"}, ensure=True) + assert temp.exists, "Table should exist after insertion" + # Now drop the table and check exists property again. + temp.drop() + assert not temp.exists, "Table should not exist after drop" + +def test_insert_no_ensure(table, db): + """Test that inserting a row with an unknown column using ensure=False omits that column.""" + # Create a dedicated table to avoid schema conflict with the main fixture. + test_tbl = db["no_ensure_test"] + # Insert a row with an extra column; since ensure is False it should not be created. + inserted = test_tbl.insert({"date": datetime(2011, 1, 2), "temperature": 25, "place": "Testville", "extra": "value"}, ensure=False) + new_row = test_tbl.find_one(id=inserted) + # The extra column should not appear in the stored row. + assert "extra" not in new_row, "Extra column should not be created when ensure is False" + +def test_create_index_non_existing(table): + """Test creating an index with a non-existent column does nothing.""" + # Create a column that is known to exist. + table.create_column("existing", VARCHAR(50)) + # Attempt to create an index on one existing and one non-existing column. + table.create_index(["existing", "non_existing"], name="idx_test") + # Since "non_existing" does not exist, the index should not be created. + assert not table.has_index(["existing", "non_existing"]), "Index should not be created with non-existent column" + +def test_update_no_op(table): + """Test updating a row with only key fields returns the count without making alterations.""" + # Insert a test row. + row = {"date": datetime(2011, 1, 2), "temperature": 15, "place": "Nowhere"} + inserted = table.insert(row, ensure=True) + # Invoke update with no additional update fields (only key values provided). 
+ count = table.update({ + "date": row["date"], + "temperature": row["temperature"], + "place": row["place"] + }, ["date", "temperature", "place"]) + # Should return the count (the number of matched rows); we’re not altering any value. + assert count >= 1, "Update with no changes should return the number of matching rows" + +def test_chunked_insert_edge_case(table): + """Test that a chunked insert with empty data does not modify the table.""" + initial_len = len(table) + with chunked.ChunkedInsert(table) as ci: + # Do not insert any data. + pass + assert len(table) == initial_len, "Empty chunked insert should not change table count" + +def test_stream_results(table): + """Test that streamed query returns the expected number of rows.""" + # Insert a few rows with a distinct place. + for i in range(3): + table.insert({"date": datetime(2011, 1, 2), "temperature": 20+i, "place": "Streamville"}, ensure=True) + rows = list(table.find(place="Streamville", _streamed=True, _step=1)) + assert len(rows) == 3, "Streamed query should return exactly three rows" + +# End of inserted tests. +def test_unknown_operator_filter(table): + """Test that using an unknown operator yields no results.""" + # Insert a row with a distinctive place to later filter on. + table.insert({"date": datetime(2021, 1, 1), "temperature": 25, "place": "UnknownOp"}, ensure=True) + # Use an unknown operator 'foobar' – _generate_clause will return false() so no row should match. + results = list(table.find(place={"foobar": "UnknownOp"})) + assert len(results) == 0, "Expected no results with an unknown operator filter" + +def test_order_by_invalid_column(table): + """Test that ordering by a non-existing column does not break the query.""" + # Order by a non-existent column should simply be ignored. + results = list(table.find(order_by=["nonexistent"])) + # Expect all rows to be returned. 
+ assert len(results) == len(table), "Expected all rows when ordering by an invalid column" + +def test_create_duplicate_column(db, table): + """Test that creating a duplicate column does not alter the table schema.""" + initial_columns = set(table.columns) + # Create a new column 'duplicate_test' + table.create_column("duplicate_test", INTEGER) + # Calling create_column again should not add an extra column. + table.create_column("duplicate_test", INTEGER) + new_columns = set(table.columns) + # The new set should have exactly one more column than the initial set. + assert "duplicate_test" in new_columns, "Column 'duplicate_test' should exist in the schema" + assert len(new_columns) == len(initial_columns) + 1, "Duplicate column creation should not add extra columns" + +def test_update_keys_removal(table): + """Test that keys used for matching are removed from the update payload.""" + # Insert a row with an extra field. + row_data = {"date": datetime(2022, 2, 2), "temperature": 15, "place": "KeyTest", "extra": "initial"} + rid = table.insert(row_data.copy(), ensure=True) + # Update using keys (date and place) while also changing the 'extra' value. + update_data = {"date": row_data["date"], "place": row_data["place"], "extra": "updated"} + count = table.update(update_data, ["date", "place"]) + assert count >= 1, "Update should affect at least one row" + updated_row = table.find_one(id=rid) + assert updated_row.get("extra") == "updated", "Update did not modify 'extra' field correctly" +def test_update_many_empty(db): + """ + Test that update_many with an empty list of rows does not alter the table. + """ + tbl = db["empty_update_test"] + # Insert a couple of rows into the table. + tbl.insert_many([{"temp": 10}, {"temp": 20}]) + initial_count = tbl.count() + # Call update_many with an empty list. 
+ tbl.update_many([], "id") + assert tbl.count() == initial_count, "update_many with empty list should not change table rows" + +def test_upsert_many_empty(db): + """ + Test that upsert_many with an empty list does not alter the table. + """ + tbl = db["empty_upsert_test"] + tbl.insert({"name": "Alice"}, ensure=True) + initial_count = tbl.count() + tbl.upsert_many([], "id") + assert tbl.count() == initial_count, "upsert_many with empty list should not change table rows" + +def test_order_by_mixed(table): + """ + Test ordering with a mix of valid and invalid columns. + Invalid ordering columns should be ignored. + """ + # Add test rows with specific temperatures. + table.insert({"date": datetime(2022, 1, 1), "temperature": 15, "place": "X"}) + table.insert({"date": datetime(2022, 1, 2), "temperature": 10, "place": "Y"}) + table.insert({"date": datetime(2022, 1, 3), "temperature": 20, "place": "Z"}) + # Order by valid temperature ascending; an invalid column "nonexistent" should be ignored. + results = list(table.find(order_by=["temperature", "nonexistent"])) + temps = [row["temperature"] for row in results] + assert temps == sorted(temps), "Rows should be ordered by temperature ascending" + + # Order by descending temperature using valid flag; again invalid should be ignored. + results_desc = list(table.find(order_by=["-temperature", "nonexistent"])) + temps_desc = [row["temperature"] for row in results_desc] + assert temps_desc == sorted(temps_desc, reverse=True), "Rows should be ordered by temperature descending" + +def test_insert_boolean_and_null(table): + """ + Test inserting rows with boolean and None values. + """ + # Insert a row with a None temperature and a boolean value in 'active'. 
+ inserted_id = table.insert({"date": datetime(2020, 5, 5), "temperature": None, "place": "TestVille", "active": True}, ensure=True) + row = table.find_one(id=inserted_id) + assert row["temperature"] is None, "Temperature should be None" + assert row["active"] is True, "Active should be True" \ No newline at end of file diff --git a/tests/test_setup.py b/tests/test_setup.py new file mode 100644 index 0000000..29068e6 --- /dev/null +++ b/tests/test_setup.py @@ -0,0 +1,181 @@ +import importlib +import pytest +import setuptools + +class TestSetup: + """Test suite for verifying setup.py configuration.""" + + def test_setup_configuration(self, monkeypatch, tmp_path): + """Test that setup() is called with correct metadata when README.md exists.""" + # Create a dummy README.md file with test content. + readme = tmp_path / "README.md" + test_readme = "This is a dummy README for testing." + readme.write_text(test_readme) + + # Change working directory to tmp_path to pick up the dummy README. + monkeypatch.chdir(tmp_path) + + # Dummy function to capture the arguments passed to setuptools.setup. + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + + # Monkey patch setuptools.setup with the dummy_setup. + monkeypatch.setattr(setuptools, "setup", dummy_setup) + + # Remove setup module from sys.modules to force re-execution of top-level code. + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + + # Load the setup module to trigger the setup() call. + import setup + importlib.reload(setup) + + # Verify that the captured setup arguments contain expected metadata. 
+ assert captured.get("name") == "dataset" + assert captured.get("version") == "1.6.0" + assert captured.get("long_description") == test_readme + assert "sqlalchemy >= 2.0.15, < 3.0.0" in captured.get("install_requires", []) + + def test_missing_readme(self, monkeypatch, tmp_path): + """Test that missing README.md file raises FileNotFoundError during module import.""" + # Change working directory to tmp_path which does not contain a README.md. + monkeypatch.chdir(tmp_path) + + # Remove setup module from sys.modules to force re-execution of top-level code. + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + + # Verify that loading setup.py without README.md raises FileNotFoundError. + with pytest.raises(FileNotFoundError): + import setup + importlib.reload(setup) + def test_setup_complete_metadata(self, monkeypatch, tmp_path): + """Test that setup() is called with complete metadata from setup.py.""" + # Create a dummy README.md file with test content for complete metadata test. + readme = tmp_path / "README.md" + test_readme = "Complete metadata test." + readme.write_text(test_readme) + + # Change working directory to tmp_path so that setup.py picks up our dummy README. + monkeypatch.chdir(tmp_path) + + # Dummy function to capture the arguments passed to setuptools.setup. + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + + # Monkey patch setuptools.setup with the dummy_setup. + monkeypatch.setattr(setuptools, "setup", dummy_setup) + + # Remove setup module from sys.modules to force re-execution of top-level code. + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + + # Load the setup module to trigger the setup() call. 
+ import setup + importlib.reload(setup) + + # Verify that all of the expected metadata keys and values are present in captured setup() + assert captured.get("name") == "dataset" + assert captured.get("version") == "1.6.0" + assert captured.get("description") == "Toolkit for Python-based database access." + assert captured.get("long_description") == test_readme + assert isinstance(captured.get("classifiers"), list) and len(captured.get("classifiers")) > 0 + assert captured.get("keywords") == "sql sqlalchemy etl loading utility" + assert captured.get("author") == "Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer" + assert captured.get("author_email") == "friedrich.lindenberg@gmail.com" + assert captured.get("url") == "http://github.com/pudo/dataset" + assert captured.get("license") == "MIT" + assert isinstance(captured.get("install_requires"), list) + assert "sqlalchemy >= 2.0.15, < 3.0.0" in captured.get("install_requires") + assert captured.get("extras_require") == { + "dev": [ + "pip", + "pytest", + "wheel", + "flake8", + "coverage", + "psycopg2-binary", + "PyMySQL", + "cryptography", + ] + } + assert captured.get("tests_require") == ["pytest"] + assert captured.get("test_suite") == "test" + # packages is generated via find_packages; we check that a value is present. 
+ assert captured.get("packages") is not None + assert captured.get("namespace_packages") == [] + assert captured.get("include_package_data") is False + assert captured.get("zip_safe") is False + def test_find_packages_called(self, monkeypatch, tmp_path): + """Test that the packages parameter is computed using monkeypatched find_packages.""" + readme = tmp_path / "README.md" + readme.write_text("dummy") + monkeypatch.chdir(tmp_path) + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + monkeypatch.setattr(setuptools, "setup", dummy_setup) + monkeypatch.setattr(setuptools, "find_packages", lambda exclude: ["dummy_pkg"]) + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + import setup + importlib.reload(setup) + assert captured.get("packages") == ["dummy_pkg"] + + def test_install_requires_length(self, monkeypatch, tmp_path): + """Test that install_requires contains exactly three dependencies with correct content.""" + readme = tmp_path / "README.md" + readme.write_text("dummy") + monkeypatch.chdir(tmp_path) + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + monkeypatch.setattr(setuptools, "setup", dummy_setup) + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + import setup + importlib.reload(setup) + install_requires = captured.get("install_requires") + assert isinstance(install_requires, list) + assert len(install_requires) == 3 + assert "sqlalchemy >= 2.0.15, < 3.0.0" in install_requires + assert "alembic >= 1.11.1" in install_requires + assert "banal >= 1.0.1" in install_requires + def test_long_description_content_type(self, monkeypatch, tmp_path): + """Test that long_description_content_type is set to 'text/markdown'.""" + readme = tmp_path / "README.md" + readme.write_text("Dummy content for long_description_content_type test") + monkeypatch.chdir(tmp_path) + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + monkeypatch.setattr(setuptools, "setup", dummy_setup) 
+ import sys + if "setup" in sys.modules: + del sys.modules["setup"] + import setup + importlib.reload(setup) + assert captured.get("long_description_content_type") == "text/markdown" + + def test_entry_points(self, monkeypatch, tmp_path): + """Test that the entry_points field is an empty dictionary.""" + readme = tmp_path / "README.md" + readme.write_text("Dummy content for entry_points test") + monkeypatch.chdir(tmp_path) + captured = {} + def dummy_setup(**kwargs): + captured.update(kwargs) + monkeypatch.setattr(setuptools, "setup", dummy_setup) + import sys + if "setup" in sys.modules: + del sys.modules["setup"] + import setup + importlib.reload(setup) + assert captured.get("entry_points") == {} \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..41a2889 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,304 @@ +import pytest +from collections import namedtuple, OrderedDict +from urllib.parse import urlparse, parse_qs +from hashlib import sha1 +from dataset.util import ( + convert_row, + iter_result_proxy, + make_sqlite_url, + ResultIter, + normalize_column_name, + normalize_column_key, + normalize_table_name, + safe_url, + index_name, + pad_chunk_columns, + QUERY_STEP, + row_type +) + +class DummyResultProxy: + """A dummy result proxy to simulate fetchall and fetchmany behavior.""" + def __init__(self, chunks): + self.chunks = chunks + self.index = 0 + def fetchall(self): + if self.index < len(self.chunks): + result = self.chunks[self.index] + self.index += 1 + return result + return [] + def fetchmany(self, size): + if self.index < len(self.chunks): + result = self.chunks[self.index] + self.index += 1 + return result + return [] + +class DummyClosedCursor: + """A dummy cursor that simulates an already closed resource.""" + def keys(self): + from sqlalchemy.exc import ResourceClosedError + raise ResourceClosedError + def close(self): + pass + +class TestUtil: + def test_convert_row_valid(self): + 
class TestUtil:
    """Unit tests for the helpers in dataset/util.py."""

    def test_convert_row_valid(self):
        """convert_row wraps a row object in the requested row type."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        row_obj = DummyRow(1, "x")
        result = convert_row(OrderedDict, row_obj)
        assert result == OrderedDict([("a", 1), ("b", "x")])

    def test_convert_row_none(self):
        """convert_row passes None through unchanged."""
        assert convert_row(OrderedDict, None) is None

    def test_iter_result_proxy_fetchall(self):
        """Without a step size, iter_result_proxy drains rows via fetchall()."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x"), DummyRow(2, "y")]
        rp = DummyResultProxy(chunks=[rows])
        assert list(iter_result_proxy(rp)) == rows

    def test_iter_result_proxy_fetchmany(self):
        """With a step size, iter_result_proxy drains rows chunk by chunk."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows_part1 = [DummyRow(1, "x")]
        rows_part2 = [DummyRow(2, "y")]
        rp = DummyResultProxy(chunks=[rows_part1, rows_part2])
        assert list(iter_result_proxy(rp, step=1)) == rows_part1 + rows_part2

    def test_make_sqlite_url_default(self):
        """make_sqlite_url with no options yields a plain sqlite URL."""
        assert make_sqlite_url("test.db") == "sqlite:///test.db"

    def test_make_sqlite_url_with_params(self):
        """All keyword options must end up in the query string of a file: URL."""
        url = make_sqlite_url(
            "test.db",
            cache="shared",
            timeout=30,
            mode="ro",
            check_same_thread=False,
            immutable=True,
            nolock=True,
        )
        assert url.startswith("sqlite:///file:test.db?")
        params = parse_qs(url.split("?", 1)[1])
        assert params.get("cache") == ["shared"]
        assert params.get("timeout") == ["30"]
        assert params.get("mode") == ["ro"]
        assert params.get("nolock") == ["1"]
        assert params.get("immutable") == ["1"]
        assert params.get("check_same_thread") == ["false"]
        assert params.get("uri") == ["true"]

    def test_result_iter_normal(self):
        """ResultIter converts each fetched row into the requested row type."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x"), DummyRow(2, "y")]

        class DummyCursorForResultIter:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                # First call returns everything; later calls return nothing.
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                self.closed = True

        iter_obj = ResultIter(DummyCursorForResultIter(rows), row_type=OrderedDict)
        expected = [OrderedDict(zip(r._fields, r)) for r in rows]
        assert list(iter_obj) == expected

    def test_result_iter_closed(self):
        """A cursor raising ResourceClosedError yields no keys and no rows."""
        iter_obj = ResultIter(DummyClosedCursor())
        assert iter_obj.keys == []
        with pytest.raises(StopIteration):
            next(iter_obj)

    def test_normalize_column_name_valid(self):
        """Valid column names are stripped of surrounding whitespace."""
        assert normalize_column_name(" column_name ") == "column_name"

    def test_normalize_column_name_invalid(self):
        """Dots, dashes, and blank column names are rejected."""
        with pytest.raises(ValueError):
            normalize_column_name("column.name")
        with pytest.raises(ValueError):
            normalize_column_name("column-name")
        with pytest.raises(ValueError):
            normalize_column_name(" ")  # empty after stripping

    def test_normalize_column_key(self):
        """Keys are upper-cased with whitespace removed; None passes through."""
        assert normalize_column_key("column key") == "COLUMNKEY"
        assert normalize_column_key(" Column ") == "COLUMN"
        assert normalize_column_key(None) is None

    def test_normalize_table_name_valid(self):
        """Valid table names are stripped of surrounding whitespace."""
        assert normalize_table_name(" MyTable ") == "MyTable"

    def test_normalize_table_name_invalid(self):
        """Blank or non-string table names are rejected."""
        with pytest.raises(ValueError):
            normalize_table_name(" ")
        with pytest.raises(ValueError):
            normalize_table_name(123)

    def test_safe_url(self):
        """safe_url masks the password and leaves password-less URLs alone."""
        safe = safe_url("postgresql://user:secret@localhost/db")
        assert "secret" not in safe
        assert "*****" in safe
        url_no_pwd = "postgresql://user@localhost/db"
        assert safe_url(url_no_pwd) == url_no_pwd

    def test_index_name(self):
        """index_name embeds the table name and a 16-char digest suffix."""
        ix = index_name("mytable", ["col1", "col2"])
        assert ix.startswith("ix_mytable_")
        assert len(ix.split("_")[-1]) == 16

    def test_pad_chunk_columns(self):
        """Every record is padded so it carries all requested columns."""
        chunk = [{"a": 1}, {"b": 2}, {"a": 3, "c": 4}]
        columns = ["a", "b", "c"]
        padded = pad_chunk_columns(chunk, columns)
        for record in padded:
            for col in columns:
                assert col in record
        # Columns absent from the input record must be padded with None.
        assert padded[0]["b"] is None
        assert padded[0]["c"] is None
        assert padded[1]["a"] is None
        assert padded[1]["c"] is None

    def test_normalize_column_name_long(self):
        """Over-long ASCII names are trimmed to exactly 63 characters."""
        # All-ASCII input: 1 byte per character, so 63 chars == 63 bytes.
        assert len(normalize_column_name("a" * 100)) == 63

    def test_normalize_column_name_non_string(self):
        """Non-string input raises ValueError."""
        with pytest.raises(ValueError):
            normalize_column_name(12345)

    def test_make_sqlite_url_timeout_only(self):
        """Supplying only a timeout still produces a file: URL with uri=true."""
        url = make_sqlite_url("test_timeout.db", timeout=10)
        assert url.startswith("sqlite:///file:test_timeout.db?")
        params = parse_qs(url.split("?", 1)[1])
        assert params.get("timeout") == ["10"]
        assert params.get("uri") == ["true"]

    def test_pad_chunk_columns_empty_chunk(self):
        """An empty chunk pads to an empty list."""
        assert pad_chunk_columns([], ["a", "b"]) == []

    def test_index_name_empty_columns(self):
        """No columns still yields a deterministic digest-based index name."""
        expected_hash = sha1("".encode("utf-8")).hexdigest()[:16]
        assert index_name("table", []) == "ix_table_" + expected_hash

    def test_result_iter_next_after_iteration(self):
        """next() after exhaustion keeps raising StopIteration."""
        DummyRow = namedtuple("DummyRow", ["a", "b"])
        rows = [DummyRow(1, "x")]

        class DummyCursorForNextTest:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                self.closed = True

        iter_obj = ResultIter(DummyCursorForNextTest(rows), row_type=OrderedDict)
        list(iter_obj)  # exhaust the iterator
        with pytest.raises(StopIteration):
            next(iter_obj)

    def test_iter_result_proxy_empty_initially(self):
        """iter_result_proxy yields nothing when no rows are fetched."""
        # FIX: this docstring and body had been misplaced at the end of
        # test_convert_row_invalid, leaving this test as a bare `pass`.
        rp = DummyResultProxy(chunks=[])
        assert list(iter_result_proxy(rp)) == []

    def test_result_iter_calls_close(self):
        """ResultIter closes its cursor once iteration is exhausted."""
        DummyRow = namedtuple("DummyRow", ["a"])
        rows = [DummyRow(1)]
        closed_flag = {"closed": False}

        class DummyCursorForClose:
            def __init__(self, rows):
                self.rows = rows
                self.index = 0

            def keys(self):
                return list(self.rows[0]._fields) if self.rows else []

            def fetchall(self):
                if self.index == 0:
                    self.index = len(self.rows)
                    return self.rows
                return []

            def close(self):
                closed_flag["closed"] = True

        # Force the iterator to exhaust; ResultIter should close the cursor.
        list(ResultIter(DummyCursorForClose(rows), row_type=OrderedDict))
        assert closed_flag["closed"] is True

    def test_normalize_column_name_multibyte(self):
        """Multibyte names are trimmed so the UTF-8 encoding stays under 64 bytes."""
        # "é" encodes to 2 bytes in UTF-8; 70 of them exceed the byte limit.
        normalized = normalize_column_name("é" * 70)
        assert len(normalized.encode("utf-8")) < 64

    def test_convert_row_invalid(self):
        """Objects without a _fields attribute make convert_row raise."""
        class NoFields:
            pass

        with pytest.raises(AttributeError):
            convert_row(OrderedDict, NoFields())

# End of tests for dataset/util.py