diff --git a/tests/db/test_db.py b/tests/db/test_db.py index fc75b20..8e33584 100644 --- a/tests/db/test_db.py +++ b/tests/db/test_db.py @@ -1,5 +1,6 @@ import datetime import os +import random from typing import Optional from zoneinfo import ZoneInfo @@ -12,8 +13,7 @@ import blitzortung from blitzortung.service.general import create_time_interval -image = "postgres:16-alpine" -image = "postgis/postgis:16-3.5" +image = "postgis/postgis:18-3.6" postgres = PostgresContainer(image) @@ -99,26 +99,30 @@ def test_full_table_name(self, base): def test_get_timezone(self, base): assert_that(base.get_timezone()).is_equal_to(datetime.timezone.utc) - def test_fix_timezone(self, base): + @pytest.fixture + def cet_timezone(self): + return ZoneInfo("Europe/Berlin") + + def test_fix_timezone(self, base, cet_timezone): assert_that(base.fix_timezone(None)).is_none() - time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("CET")) + time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=cet_timezone) utc_time = datetime.datetime(2013, 1, 1, 11, 0, 0, tzinfo=ZoneInfo("UTC")) assert_that(base.fix_timezone(time)).is_equal_to(utc_time) - def test_from_bare_utc_to_timezone(self, base): - base.set_timezone(ZoneInfo("CET")) + def test_from_bare_utc_to_timezone(self, base, cet_timezone): + base.set_timezone(cet_timezone) - time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("CET")) + time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=cet_timezone) utc_time = datetime.datetime(2013, 1, 1, 11, 0, 0) assert_that(base.from_bare_utc_to_timezone(utc_time)).is_equal_to(time) - def test_from_timezone_to_bare_utc(self, base): - base.set_timezone(ZoneInfo("CET")) + def test_from_timezone_to_bare_utc(self, base, cet_timezone): + base.set_timezone(cet_timezone) - time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=ZoneInfo("CET")) + time = datetime.datetime(2013, 1, 1, 12, 0, 0, tzinfo=cet_timezone) utc_time = datetime.datetime(2013, 1, 1, 11, 0, 0) assert_that(base.from_timezone_to_bare_utc(time)).is_equal_to(utc_time) @@ -274,6 +278,36 @@ def test_get_latest_time_with_region_mismatch(strikes, strike_factory, time_inte assert result is None +def test_bench_insert(strike_factory, strikes, benchmark): + strike = strike_factory(11, 49) + def insert(): + strikes.insert(strike, 1) + + benchmark.pedantic(insert, args=(), rounds=10, iterations=20) + +def test_bench_select(strike_factory, strikes, benchmark): + for i in range(100): + strike = strike_factory(11 + random.randrange(-100, 100, 1) / 100, 49 + random.randrange(-100, 100, 1) / 100) + strikes.insert(strike, 1) + + def select(): + result = strikes.select() + assert len(list(result)) == 100 + + benchmark.pedantic(select, args=(), rounds=10, iterations=100) + +def test_bench_select_grid(strike_factory, grid_factory, strikes, time_interval, benchmark): + for i in range(100): + strike = strike_factory(11 + random.randrange(-100, 100, 1) / 100, 49 + random.randrange(-100, 100, 1) / 100) + strikes.insert(strike, 1) + + grid = grid_factory.get_for(5000) + + def select(): + result = strikes.select_grid(grid, 0, time_interval=time_interval) + assert len(list(result)) <= 100 + + benchmark.pedantic(select, args=(), rounds=10, iterations=100) @pytest.mark.parametrize("raster_size,expected", [ (100000, (1, 1, 1, 0)),