From d09184ba6c17b79412b702dd35b28e740d994a29 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 13:48:59 -0800 Subject: [PATCH 01/10] Factor out wf1 module --- .../src/app/tests/fba/test_fba_endpoint.py | 5 +- .../app/tests/morecast_v2/test_forecasts.py | 8 +- backend/packages/wps-shared/pyproject.toml | 4 + .../wps-shared/src/wps_shared/tests/common.py | 10 +- .../tests/wildfire_one/test_wildfire_one.py | 26 +- .../src/wps_shared/wildfire_one/wfwx_api.py | 318 +++++++++--- .../wildfire_one/wildfire_fetchers.py | 238 --------- backend/packages/wps-wf1/README.md | 0 backend/packages/wps-wf1/pyproject.toml | 18 + .../packages/wps-wf1/src/wps_wf1/__init__.py | 12 + .../wps-wf1/src/wps_wf1/cache_protocol.py | 11 + .../wps-wf1/src/wps_wf1/query_builders.py | 128 +++++ .../wps-wf1/src/wps_wf1/tests/conftest.py | 43 ++ .../src/wps_wf1/tests/test_query_builders.py | 54 ++ .../src/wps_wf1/tests/test_wfwx_client.py | 483 ++++++++++++++++++ .../wps-wf1/src/wps_wf1/wfwx_client.py | 155 ++++++ .../wps-wf1/src/wps_wf1/wfwx_settings.py | 10 + backend/pytest.ini | 2 + backend/uv.lock | 14 + 19 files changed, 1208 insertions(+), 331 deletions(-) delete mode 100644 backend/packages/wps-shared/src/wps_shared/wildfire_one/wildfire_fetchers.py create mode 100644 backend/packages/wps-wf1/README.md create mode 100644 backend/packages/wps-wf1/pyproject.toml create mode 100644 backend/packages/wps-wf1/src/wps_wf1/__init__.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/cache_protocol.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/query_builders.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/tests/conftest.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/tests/test_query_builders.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/tests/test_wfwx_client.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/wfwx_client.py create mode 100644 backend/packages/wps-wf1/src/wps_wf1/wfwx_settings.py diff --git 
a/backend/packages/wps-api/src/app/tests/fba/test_fba_endpoint.py b/backend/packages/wps-api/src/app/tests/fba/test_fba_endpoint.py index 108c1fcfc5..2ef8bb18f3 100644 --- a/backend/packages/wps-api/src/app/tests/fba/test_fba_endpoint.py +++ b/backend/packages/wps-api/src/app/tests/fba/test_fba_endpoint.py @@ -1,4 +1,3 @@ -import asyncio import json import math from collections import namedtuple @@ -10,10 +9,8 @@ from aiohttp import ClientSession from app.tests import get_complete_filename from fastapi.testclient import TestClient - from wps_shared.db.models.auto_spatial_advisory import ( AdvisoryHFIWindSpeed, - AdvisoryTPIStats, RunParameters, SFMSFuelType, TPIClassEnum, @@ -24,7 +21,6 @@ FireZoneHFIStats, HFIStatsResponse, HfiThreshold, - LatestSFMSRunParameterRangeResponse, SFMSRunParameter, ) from wps_shared.tests.common import default_mock_client_get @@ -281,6 +277,7 @@ def client(): yield test_client +@patch("app.routers.fba.get_auth_header", mock_get_auth_header) @pytest.mark.usefixtures("mock_jwt_decode") @pytest.mark.parametrize( "status, expected_fire_centers", [(200, "test_fba_endpoint_fire_centers.json")] diff --git a/backend/packages/wps-api/src/app/tests/morecast_v2/test_forecasts.py b/backend/packages/wps-api/src/app/tests/morecast_v2/test_forecasts.py index a8ce21210b..19973160cb 100644 --- a/backend/packages/wps-api/src/app/tests/morecast_v2/test_forecasts.py +++ b/backend/packages/wps-api/src/app/tests/morecast_v2/test_forecasts.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import pytest from math import isclose +import app from wps_shared.db.models.morecast_v2 import MorecastForecastRecord from app.morecast_v2.forecasts import ( actual_exists, @@ -242,8 +243,13 @@ def test_construct_wf1_forecast_update(): @pytest.mark.anyio @patch("aiohttp.ClientSession.get") @patch("app.morecast_v2.forecasts.get_forecasts_for_stations_by_date_range", return_value=[station_1_daily_from_wf1]) -async def test_construct_wf1_forecasts_new(_, mock_get): 
+async def test_construct_wf1_forecasts_new(_, mock_get, monkeypatch: pytest.MonkeyPatch): + async def mock_get_auth_header(_): + return {} + + monkeypatch.setattr(app.morecast_v2.forecasts, "get_no_cache_auth_header", mock_get_auth_header) result = await construct_wf1_forecasts(mock_get, [morecast_input_1, morecast_input_2], wfwx_weather_stations, "user") + assert len(result) == 2 # existing forecast assert_wf1_forecast(result[0], morecast_input_1, station_1_daily_from_wf1.forecast_id, station_1_daily_from_wf1.created_by, station_1_url, "1") diff --git a/backend/packages/wps-shared/pyproject.toml b/backend/packages/wps-shared/pyproject.toml index 9ae17e81c5..fe8ce18613 100644 --- a/backend/packages/wps-shared/pyproject.toml +++ b/backend/packages/wps-shared/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "asyncpg>=0.30.0,<1", "redis>=7.0.0,<8", "gdal==3.9.2", + "wps-wf1", ] [project.optional-dependencies] @@ -35,3 +36,6 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/wps_shared"] + +[tool.uv.sources] +wps-wf1 = { workspace = true } diff --git a/backend/packages/wps-shared/src/wps_shared/tests/common.py b/backend/packages/wps-shared/src/wps_shared/tests/common.py index 224f93797a..bac99803cf 100644 --- a/backend/packages/wps-shared/src/wps_shared/tests/common.py +++ b/backend/packages/wps-shared/src/wps_shared/tests/common.py @@ -1,10 +1,11 @@ """Mock modules/classes""" +import json import logging import os -import json -from typing import Optional from contextlib import asynccontextmanager +from typing import Optional + from wps_shared.auth import ASA_TEST_IDIR_GUID from wps_shared.tests.fixtures.loader import FixtureFinder @@ -109,6 +110,11 @@ async def json(self) -> dict: """Return json response""" return self._json + def raise_for_status(self) -> None: + """Mimic aiohttp.ClientResponse.raise_for_status().""" + if 400 <= self.status: + raise Exception(f"HTTP {self.status}") + class DefaultMockAioSession: """Mock 
aiobotocore.session.AioSession""" diff --git a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py index 3064055eb7..bece959c26 100644 --- a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py +++ b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py @@ -1,13 +1,22 @@ """Unit testing for WFWX API code""" import asyncio -from unittest.mock import patch, AsyncMock +from unittest.mock import AsyncMock, patch + import pytest +import wps_shared.wildfire_one.wfwx_post_api from fastapi import HTTPException from pytest_mock import MockFixture - -from wps_shared.wildfire_one.query_builders import BuildQueryAllForecastsByAfterStart, BuildQueryAllHourliesByRange, BuildQueryDailiesByStationCode, BuildQueryStationGroups -from wps_shared.wildfire_one.wfwx_api import WFWXWeatherStation, get_wfwx_stations_from_station_codes +from wps_shared.wildfire_one.query_builders import ( + BuildQueryAllForecastsByAfterStart, + BuildQueryAllHourliesByRange, + BuildQueryDailiesByStationCode, + BuildQueryStationGroups, +) +from wps_shared.wildfire_one.wfwx_api import ( + WFWXWeatherStation, + get_wfwx_stations_from_station_codes, +) from wps_shared.wildfire_one.wfwx_post_api import post_forecasts @@ -94,8 +103,15 @@ async def run_test(): @pytest.mark.anyio @patch("wps_shared.wildfire_one.wfwx_post_api.ClientSession") -async def test_wf1_post_failure(mock_client): +async def test_wf1_post_failure(mock_client, monkeypatch: pytest.MonkeyPatch): """Verifies that posting to WF1 raises an exception upon failure""" + + async def mock_get_auth_header(_): + return {} + + monkeypatch.setattr( + wps_shared.wildfire_one.wfwx_post_api, "get_auth_header", mock_get_auth_header + ) mock_client.post.return_value.__aenter__.return_value = AsyncMock(status=400) with pytest.raises(HTTPException): await post_forecasts(mock_client, []) diff 
--git a/backend/packages/wps-shared/src/wps_shared/wildfire_one/wfwx_api.py b/backend/packages/wps-shared/src/wps_shared/wildfire_one/wfwx_api.py index f307e5f68b..dc5bb69dd9 100644 --- a/backend/packages/wps-shared/src/wps_shared/wildfire_one/wfwx_api.py +++ b/backend/packages/wps-shared/src/wps_shared/wildfire_one/wfwx_api.py @@ -1,59 +1,77 @@ """This module contains methods for retrieving information from the WFWX Fireweather API.""" +import asyncio +import logging import math -from typing import List, Optional, Final, AsyncGenerator from datetime import datetime -import logging -import asyncio +from typing import AsyncGenerator, Dict, Final, List, Optional, Tuple + from aiohttp import ClientSession, TCPConnector +from wps_wf1.query_builders import ( + BuildQuery, + BuildQueryAllForecastsByAfterStart, + BuildQueryAllHourliesByRange, + BuildQueryByStationCode, + BuildQueryDailiesByStationCode, + BuildQueryStationGroups, + BuildQueryStations, +) +from wps_wf1.wfwx_client import WfwxClient +from wps_wf1.wfwx_settings import WfwxSettings + from wps_shared import config from wps_shared.data.ecodivision_seasons import EcodivisionSeasons from wps_shared.db.crud.hfi_calc import get_fire_centre_station_codes -from wps_shared.db.models.observations import HourlyActual +from wps_shared.db.crud.stations import _get_noon_date from wps_shared.db.models.forecasts import NoonForecast +from wps_shared.db.models.observations import HourlyActual +from wps_shared.schemas.fba import FireCentre from wps_shared.schemas.morecast_v2 import StationDailyFromWF1 from wps_shared.schemas.observations import WeatherStationHourlyReadings -from wps_shared.schemas.fba import FireCentre -from wps_shared.schemas.stations import WeatherStation, WeatherVariables +from wps_shared.schemas.stations import ( + DetailedWeatherStationProperties, + GeoJsonDetailedWeatherStation, + WeatherStation, + WeatherStationGeometry, + WeatherVariables, +) +from wps_shared.utils.redis import create_redis from 
wps_shared.wildfire_one.schema_parsers import ( WF1RecordTypeEnum, WFWXWeatherStation, + dailies_list_mapper, fire_center_mapper, + parse_hourly, + parse_hourly_actual, parse_noon_forecast, parse_station, - parse_hourly_actual, station_list_mapper, unique_weather_stations_mapper, weather_indeterminate_list_mapper, weather_station_group_mapper, wfwx_station_list_mapper, - dailies_list_mapper, -) -from wps_shared.wildfire_one.query_builders import ( - BuildQueryAllForecastsByAfterStart, - BuildQueryStations, - BuildQueryAllHourliesByRange, - BuildQueryByStationCode, - BuildQueryDailiesByStationCode, - BuildQueryStationGroups, ) from wps_shared.wildfire_one.util import is_station_valid -from wps_shared.wildfire_one.wildfire_fetchers import ( - fetch_access_token, - fetch_detailed_geojson_stations, - fetch_paged_response_generator, - fetch_hourlies, - fetch_raw_dailies_for_all_stations, - fetch_stations_by_group_id, -) logger = logging.getLogger(__name__) +def create_wps_wf1_client(session: ClientSession): + wfwx_settings = WfwxSettings( + base_url=config.get("WFWX_BASE_URL"), + auth_url=config.get("WFWX_AUTH_URL"), + user=config.get("WFWX_USER"), + secret=config.get("WFWX_SECRET"), + ) + wfwx_client = WfwxClient(session=session, settings=wfwx_settings, cache=create_redis()) + return wfwx_client + + async def get_auth_header(session: ClientSession) -> dict: """Get WFWX auth header""" + wfwx_client = create_wps_wf1_client(session) # Fetch access token - token = await fetch_access_token(session) + token = await wfwx_client.fetch_access_token(int(config.get("REDIS_AUTH_CACHE_EXPIRY", 600))) # Construct the header. 
header = {"Authorization": f"Bearer {token['access_token']}"} return header @@ -74,12 +92,19 @@ async def get_stations_by_codes(station_codes: List[int]) -> List[WeatherStation with EcodivisionSeasons(",".join([str(code) for code in station_codes])) as eco_division: async with ClientSession() as session: header = await get_auth_header(session) + wfwx_client = create_wps_wf1_client(session) stations = [] # 1 week seems a reasonable period to cache stations for. - redis_station_cache_expiry: Final = int(config.get("REDIS_STATION_CACHE_EXPIRY", 604800)) + redis_station_cache_expiry: Final = int( + config.get("REDIS_STATION_CACHE_EXPIRY", 604800) + ) # Iterate through "raw" station data. - iterator = fetch_paged_response_generator( - session, header, BuildQueryByStationCode(station_codes), "stations", use_cache=True, cache_expiry_seconds=redis_station_cache_expiry + iterator = wfwx_client.fetch_paged_response_generator( + header, + BuildQueryByStationCode(station_codes), + "stations", + use_cache=True, + ttl=redis_station_cache_expiry, ) async for raw_station in iterator: # If the station is valid, add it to our list of stations. @@ -92,16 +117,57 @@ async def get_stations_by_codes(station_codes: List[int]) -> List[WeatherStation async def get_station_data(session: ClientSession, header: dict, mapper=station_list_mapper): """Get list of stations from WFWX Fireweather API.""" logger.info("Using WFWX to retrieve station list") + wfwx_client = create_wps_wf1_client(session) # 1 week seems a reasonable period to cache stations for. redis_station_cache_expiry: Final = int(config.get("REDIS_STATION_CACHE_EXPIRY", 604800)) # Iterate through "raw" station data. 
- raw_stations = fetch_paged_response_generator(session, header, BuildQueryStations(), "stations", use_cache=True, cache_expiry_seconds=redis_station_cache_expiry) + raw_stations = wfwx_client.fetch_paged_response_generator( + header, + BuildQueryStations(), + "stations", + use_cache=True, + ttl=redis_station_cache_expiry, + ) # Map list of stations into desired shape stations = await mapper(raw_stations) logger.debug("total stations: %d", len(stations)) return stations +async def get_detailed_geojson_stations( + session: ClientSession, headers: dict, query_builder: BuildQuery +) -> Tuple[Dict[int, GeoJsonDetailedWeatherStation], Dict[str, int]]: + """Fetch and marshall geojson station data""" + stations = {} + id_to_code_map = {} + # 1 week seems a reasonable period to cache stations for. + redis_station_cache_expiry: Final = int(config.get("REDIS_STATION_CACHE_EXPIRY", 604800)) + wfwx_client = create_wps_wf1_client(session) + # Put the stations in a nice dictionary. + async for raw_station in wfwx_client.fetch_paged_response_generator( + headers, query_builder, "stations", True, redis_station_cache_expiry + ): + station_code = raw_station.get("stationCode") + station_status = raw_station.get("stationStatus", {}).get("id") + # Because we can't filter on status in the RSQL, we have to manually exclude stations that are + # not active. + if is_station_valid(raw_station): + id_to_code_map[raw_station.get("id")] = station_code + geojson_station = GeoJsonDetailedWeatherStation( + properties=DetailedWeatherStationProperties( + code=station_code, name=raw_station.get("displayLabel") + ), + geometry=WeatherStationGeometry( + coordinates=[raw_station.get("longitude"), raw_station.get("latitude")] + ), + ) + stations[station_code] = geojson_station + else: + logger.debug("station %s, status %s", station_code, station_status) + + return stations, id_to_code_map + + async def get_detailed_stations(time_of_interest: datetime): """ We do two things in parallel. 
@@ -114,10 +180,16 @@ async def get_detailed_stations(time_of_interest: datetime): async with ClientSession(connector=conn) as session: # Get the authentication header header = await get_auth_header(session) + noon_time_of_interest = _get_noon_date(time_of_interest) + wfwx_client = create_wps_wf1_client(session) # Fetch the daily (noon) values for all the stations - dailies_task = asyncio.create_task(fetch_raw_dailies_for_all_stations(session, header, time_of_interest)) + dailies_task = asyncio.create_task( + wfwx_client.fetch_raw_dailies_for_all_stations(header, noon_time_of_interest) + ) # Fetch all the stations - stations_task = asyncio.create_task(fetch_detailed_geojson_stations(session, header, BuildQueryStations())) + stations_task = asyncio.create_task( + get_detailed_geojson_stations(session, header, BuildQueryStations()) + ) # Await completion of concurrent tasks. dailies = await dailies_task @@ -129,7 +201,10 @@ async def get_detailed_stations(time_of_interest: datetime): station_code = id_to_code_map.get(station_id, None) if station_code: station = stations[station_code] - weather_variable = WeatherVariables(temperature=daily.get("temperature"), relative_humidity=daily.get("relativeHumidity")) + weather_variable = WeatherVariables( + temperature=daily.get("temperature"), + relative_humidity=daily.get("relativeHumidity"), + ) record_type = daily.get("recordType").get("id") if record_type in ["ACTUAL", "MANUAL"]: station.properties.observations = weather_variable @@ -143,8 +218,31 @@ async def get_detailed_stations(time_of_interest: datetime): return list(stations.values()) +async def get_hourly_for_station( + session, header, raw_station, start_timestamp, end_timestamp, eco_division, use_cache, ttl +): + wfwx_client = create_wps_wf1_client(session) + hourlies_json = await wfwx_client.fetch_hourlies( + raw_station, header, start_timestamp, end_timestamp, use_cache, ttl + ) + hourlies = [] + for hourly in hourlies_json["_embedded"]["hourlies"]: + # We only 
accept "ACTUAL" values + if hourly.get("hourlyMeasurementTypeCode", "").get("id") == "ACTUAL": + hourlies.append(parse_hourly(hourly)) + + return WeatherStationHourlyReadings( + values=hourlies, station=parse_station(raw_station, eco_division) + ) + + async def get_hourly_readings( - session: ClientSession, header: dict, station_codes: List[int], start_timestamp: datetime, end_timestamp: datetime, use_cache: bool = False + session: ClientSession, + header: dict, + station_codes: List[int], + start_timestamp: datetime, + end_timestamp: datetime, + use_cache: bool = False, ) -> List[WeatherStationHourlyReadings]: """Get the hourly readings for the list of station codes provided.""" # Create a list containing all the tasks to run in parallel. @@ -152,7 +250,10 @@ async def get_hourly_readings( # 1 week seems a reasonable period to cache stations for. redis_station_cache_expiry: Final = int(config.get("REDIS_STATION_CACHE_EXPIRY", 604800)) # Iterate through "raw" station data. - iterator = fetch_paged_response_generator(session, header, BuildQueryByStationCode(station_codes), "stations", True, redis_station_cache_expiry) + wfwx_client = create_wps_wf1_client(session) + iterator = wfwx_client.fetch_paged_response_generator( + header, BuildQueryByStationCode(station_codes), "stations", True, redis_station_cache_expiry + ) raw_stations = [] eco_division_key = "" # not ideal - we iterate through the stations twice. 
1'st time to get the list of station codes, @@ -162,28 +263,51 @@ async def get_hourly_readings( raw_stations.append(raw_station) station_codes.add(raw_station.get("stationCode")) eco_division_key = ",".join(str(code) for code in station_codes) + cache_expiry_seconds: Final = int( + config.get("REDIS_HOURLIES_BY_STATION_CODE_CACHE_EXPIRY", 300) + ) with EcodivisionSeasons(eco_division_key) as eco_division: for raw_station in raw_stations: - task = asyncio.create_task(fetch_hourlies(session, raw_station, header, start_timestamp, end_timestamp, use_cache, eco_division)) + task = asyncio.create_task( + get_hourly_for_station( + session, + header, + raw_station, + start_timestamp, + end_timestamp, + eco_division, + use_cache, + cache_expiry_seconds, + ) + ) tasks.append(task) # Run the tasks concurrently, waiting for them all to complete. return await asyncio.gather(*tasks) -async def get_noon_forecasts_all_stations(session: ClientSession, header: dict, start_timestamp: datetime) -> List[NoonForecast]: +async def get_noon_forecasts_all_stations( + session: ClientSession, header: dict, start_timestamp: datetime +) -> List[NoonForecast]: """Get the noon forecasts for all stations.""" noon_forecasts: List[NoonForecast] = [] + wfwx_client = create_wps_wf1_client(session) # Iterate through "raw" forecast data. 
- forecasts_iterator = fetch_paged_response_generator(session, header, BuildQueryAllForecastsByAfterStart(math.floor(start_timestamp.timestamp() * 1000)), "dailies") + forecasts_iterator = wfwx_client.fetch_paged_response_generator( + header, + BuildQueryAllForecastsByAfterStart(math.floor(start_timestamp.timestamp() * 1000)), + "dailies", + ) forecasts = [] async for noon_forecast in forecasts_iterator: forecasts.append(noon_forecast) - stations: List[WFWXWeatherStation] = await get_station_data(session, header, mapper=wfwx_station_list_mapper) + stations: List[WFWXWeatherStation] = await get_station_data( + session, header, mapper=wfwx_station_list_mapper + ) station_code_dict = {station.wfwx_id: station.code for station in stations} @@ -199,21 +323,31 @@ async def get_noon_forecasts_all_stations(session: ClientSession, header: dict, return noon_forecasts -async def get_hourly_actuals_all_stations(session: ClientSession, header: dict, start_timestamp: datetime, end_timestamp: datetime) -> List[HourlyActual]: +async def get_hourly_actuals_all_stations( + session: ClientSession, header: dict, start_timestamp: datetime, end_timestamp: datetime +) -> List[HourlyActual]: """Get the hourly actuals for all stations.""" hourly_actuals: List[HourlyActual] = [] + wfwx_client = create_wps_wf1_client(session) # Iterate through "raw" hourlies data. 
- hourlies_iterator = fetch_paged_response_generator( - session, header, BuildQueryAllHourliesByRange(math.floor(start_timestamp.timestamp() * 1000), math.floor(end_timestamp.timestamp() * 1000)), "hourlies" + hourlies_iterator = wfwx_client.fetch_paged_response_generator( + header, + BuildQueryAllHourliesByRange( + math.floor(start_timestamp.timestamp() * 1000), + math.floor(end_timestamp.timestamp() * 1000), + ), + "hourlies", ) hourlies = [] async for hourly in hourlies_iterator: hourlies.append(hourly) - stations: List[WFWXWeatherStation] = await get_station_data(session, header, mapper=wfwx_station_list_mapper) + stations: List[WFWXWeatherStation] = await get_station_data( + session, header, mapper=wfwx_station_list_mapper + ) station_code_dict = {station.wfwx_id: station.code for station in stations} @@ -229,30 +363,9 @@ async def get_hourly_actuals_all_stations(session: ClientSession, header: dict, return hourly_actuals -async def get_daily_actuals_for_stations_between_dates(session: ClientSession, header: dict, start_datetime: datetime, end_datetime: datetime, stations: List[WeatherStation]): - """Get the daily actuals for each station.""" - - wfwx_station_ids = [station.wfwx_station_uuid for station in stations] - - start_timestamp = math.floor(start_datetime.timestamp() * 1000) - end_timestamp = math.floor(end_datetime.timestamp() * 1000) - - cache_expiry_seconds: Final = int(config.get("REDIS_DAILIES_BY_STATION_CODE_CACHE_EXPIRY", 300)) - use_cache = config.get("REDIS_USE") == "True" - - # Iterate through "raw" hourlies data. 
- dailies_iterator = fetch_paged_response_generator( - session, header, BuildQueryDailiesByStationCode(start_timestamp, end_timestamp, wfwx_station_ids), "dailies", use_cache=use_cache, cache_expiry_seconds=cache_expiry_seconds - ) - - dailies = [] - async for daily in dailies_iterator: - dailies.append(daily) - - return dailies - - -async def get_wfwx_stations_from_station_codes(session: ClientSession, header, station_codes: Optional[List[int]]) -> List[WFWXWeatherStation]: +async def get_wfwx_stations_from_station_codes( + session: ClientSession, header, station_codes: Optional[List[int]] +) -> List[WFWXWeatherStation]: """Return the WFWX station ids from WFWX API given a list of station codes.""" # All WFWX stations are requested because WFWX returns a malformed JSON response when too @@ -278,9 +391,22 @@ async def get_wfwx_stations_from_station_codes(session: ClientSession, header, s return requested_stations -async def get_raw_dailies_in_range_generator(session: ClientSession, header: dict, wfwx_station_ids: List[str], start_timestamp: int, end_timestamp: int) -> AsyncGenerator[dict, None]: +async def get_raw_dailies_in_range_generator( + session: ClientSession, + header: dict, + wfwx_station_ids: List[str], + start_timestamp: int, + end_timestamp: int, +) -> AsyncGenerator[dict, None]: """Get the raw dailies in range for a list of WFWX station ids.""" - return fetch_paged_response_generator(session, header, BuildQueryDailiesByStationCode(start_timestamp, end_timestamp, wfwx_station_ids), "dailies", True, 60) + wfwx_client = create_wps_wf1_client(session) + return wfwx_client.fetch_paged_response_generator( + header, + BuildQueryDailiesByStationCode(start_timestamp, end_timestamp, wfwx_station_ids), + "dailies", + True, + 60, + ) async def get_dailies_generator( @@ -307,14 +433,16 @@ async def get_dailies_generator( cache_expiry_seconds: Final = int(config.get("REDIS_DAILIES_BY_STATION_CODE_CACHE_EXPIRY", 300)) use_cache = check_cache is True and 
config.get("REDIS_USE") == "True" logger.info(f"Using cache: {use_cache}") + wfwx_client = create_wps_wf1_client(session) - dailies_iterator = fetch_paged_response_generator( - session, + dailies_iterator = wfwx_client.fetch_paged_response_generator( header, - BuildQueryDailiesByStationCode(timestamp_of_interest, end_timestamp_of_interest, wfwx_station_ids), + BuildQueryDailiesByStationCode( + timestamp_of_interest, end_timestamp_of_interest, wfwx_station_ids + ), "dailies", use_cache=use_cache, - cache_expiry_seconds=cache_expiry_seconds, + ttl=cache_expiry_seconds, ) return dailies_iterator @@ -330,12 +458,21 @@ async def get_fire_centers( async def get_dailies_for_stations_and_date( - session: ClientSession, header: dict, start_time_of_interest: datetime, end_time_of_interest: datetime, unique_station_codes: List[int], mapper=dailies_list_mapper + session: ClientSession, + header: dict, + start_time_of_interest: datetime, + end_time_of_interest: datetime, + unique_station_codes: List[int], + mapper=dailies_list_mapper, ): # get station information from the wfwx api - wfwx_stations = await get_wfwx_stations_from_station_codes(session, header, unique_station_codes) + wfwx_stations = await get_wfwx_stations_from_station_codes( + session, header, unique_station_codes + ) # get the dailies for all the stations - raw_dailies = await get_dailies_generator(session, header, wfwx_stations, start_time_of_interest, end_time_of_interest) + raw_dailies = await get_dailies_generator( + session, header, wfwx_stations, start_time_of_interest, end_time_of_interest + ) yesterday_dailies = await mapper(raw_dailies, WF1RecordTypeEnum.ACTUAL) @@ -343,13 +480,26 @@ async def get_dailies_for_stations_and_date( async def get_forecasts_for_stations_by_date_range( - session: ClientSession, header: dict, start_time_of_interest: datetime, end_time_of_interest: datetime, unique_station_codes: List[int], check_cache=True, mapper=dailies_list_mapper + session: ClientSession, + header: dict, + 
start_time_of_interest: datetime, + end_time_of_interest: datetime, + unique_station_codes: List[int], + check_cache=True, + mapper=dailies_list_mapper, ) -> List[StationDailyFromWF1]: # get station information from the wfwx api - wfwx_stations = await get_wfwx_stations_from_station_codes(session, header, unique_station_codes) + wfwx_stations = await get_wfwx_stations_from_station_codes( + session, header, unique_station_codes + ) # get the daily forecasts for all the stations in the date range raw_dailies = await get_dailies_generator( - session=session, header=header, wfwx_stations=wfwx_stations, time_of_interest=start_time_of_interest, end_time_of_interest=end_time_of_interest, check_cache=check_cache + session=session, + header=header, + wfwx_stations=wfwx_stations, + time_of_interest=start_time_of_interest, + end_time_of_interest=end_time_of_interest, + check_cache=check_cache, ) forecast_dailies = await mapper(raw_dailies, WF1RecordTypeEnum.FORECAST) @@ -367,7 +517,9 @@ async def get_daily_determinates_for_stations_and_date( check_cache: bool = True, ): # get station information from the wfwx api - wfwx_stations = await get_wfwx_stations_from_station_codes(session, header, unique_station_codes) + wfwx_stations = await get_wfwx_stations_from_station_codes( + session, header, unique_station_codes + ) # get the dailies for all the stations raw_dailies = await get_dailies_generator( session, header, wfwx_stations, start_time_of_interest, end_time_of_interest, check_cache @@ -382,7 +534,10 @@ async def get_station_groups(mapper=weather_station_group_mapper): """Get the station groups created by all users from Wild Fire One internal API.""" async with ClientSession() as session: header = await get_auth_header(session) - all_station_groups = fetch_paged_response_generator(session, header, BuildQueryStationGroups(), "stationGroups", use_cache=False) + wfwx_client = create_wps_wf1_client(session) + all_station_groups = wfwx_client.fetch_paged_response_generator( + 
header, BuildQueryStationGroups(), "stationGroups" + ) # Map list of stations into desired shape mapped_station_groups = await mapper(all_station_groups) logger.debug("total station groups: %d", len(mapped_station_groups)) @@ -394,8 +549,9 @@ async def get_stations_by_group_ids(group_ids: List[str], mapper=unique_weather_ stations_in_groups = [] async with ClientSession() as session: headers = await get_auth_header(session) + wfwx_client = create_wps_wf1_client(session) for group_id in group_ids: - stations = await fetch_stations_by_group_id(session, headers, group_id) + stations = await wfwx_client.fetch_stations_by_group_id(headers, group_id) stations_in_group = mapper(stations) stations_in_groups.extend(stations_in_group) return stations_in_groups diff --git a/backend/packages/wps-shared/src/wps_shared/wildfire_one/wildfire_fetchers.py b/backend/packages/wps-shared/src/wps_shared/wildfire_one/wildfire_fetchers.py deleted file mode 100644 index 3b7dd6ec8d..0000000000 --- a/backend/packages/wps-shared/src/wps_shared/wildfire_one/wildfire_fetchers.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Functions that request and marshall WFWX API responses into our schemas""" - -import math -import logging -from datetime import datetime -from typing import AsyncGenerator, Dict, Tuple, Final -import json -from urllib.parse import urlencode -from aiohttp.client import ClientSession, BasicAuth -from wps_shared.data.ecodivision_seasons import EcodivisionSeasons -from wps_shared.rocketchat_notifications import send_rocketchat_notification -from wps_shared.schemas.observations import WeatherStationHourlyReadings -from wps_shared.schemas.stations import DetailedWeatherStationProperties, GeoJsonDetailedWeatherStation, WeatherStationGeometry -from wps_shared.db.crud.stations import _get_noon_date -from wps_shared.wildfire_one.query_builders import BuildQuery -from wps_shared import config -from wps_shared.wildfire_one.schema_parsers import parse_hourly, parse_station -from 
wps_shared.wildfire_one.util import is_station_valid -from wps_shared.utils.redis import create_redis - -logger = logging.getLogger(__name__) - - -async def _fetch_cached_response(session: ClientSession, headers: dict, url: str, params: dict, cache_expiry_seconds: int): - cache = create_redis() - key = f"{url}?{urlencode(params)}" - try: - cached_json = cache.get(key) - except Exception as error: - cached_json = None - logger.error(error, exc_info=error) - if cached_json: - logger.info("redis cache hit %s", key) - response_json = json.loads(cached_json.decode()) - else: - logger.info("redis cache miss %s", key) - async with session.get(url, headers=headers, params=params) as response: - try: - response_json = await response.json() - except json.decoder.JSONDecodeError as error: - logger.error(error, exc_info=error) - text = await response.text() - logger.error("response.text() = %s", text) - send_rocketchat_notification(f"JSONDecodeError, response.text() = {text}", error) - raise - try: - if response.status == 200: - cache.set(key, json.dumps(response_json).encode(), ex=cache_expiry_seconds) - except Exception as error: - logger.error(error, exc_info=error) - return response_json - - -async def fetch_paged_response_generator( - session: ClientSession, headers: dict, query_builder: BuildQuery, content_key: str, use_cache: bool = False, cache_expiry_seconds: int = 86400 -) -> AsyncGenerator[dict, None]: - """Asynchronous generator for iterating through responses from the API. - The response is a paged response, but this generator abstracts that away. - """ - # We don't know how many pages until our first call - so we assume one page to start with. - total_pages = 1 - page_count = 0 - while page_count < total_pages: - # Build up the request URL. 
- url, params = query_builder.query(page_count) - logger.debug("loading page %d...", page_count) - if use_cache and config.get("REDIS_USE") == "True": - logger.info("Using cache") - # We've been told and configured to use the redis cache. - response_json = await _fetch_cached_response(session, headers, url, params, cache_expiry_seconds) - else: - logger.info("Not using cache") - async with session.get(url, headers=headers, params=params) as response: - response_json = await response.json() - logger.debug("done loading page %d.", page_count) - - # keep this code around for dumping responses to a json file - useful for when you're writing - # tests to grab actual responses to use in fixtures. - # import base64 - # TODO: write a beter way to make a temporary filename - # fname = 'thing_{}_{}.json'.format(base64.urlsafe_b64encode(url.encode()), random.randint(0, 1000)) - # with open(fname, 'w') as f: - # json.dump(response_json, f) - - # Update the total page count. - total_pages = response_json["page"]["totalPages"] if "page" in response_json else 1 - for response_object in response_json["_embedded"][content_key]: - yield response_object - # Keep track of our page count. - page_count = page_count + 1 - - -async def fetch_detailed_geojson_stations(session: ClientSession, headers: dict, query_builder: BuildQuery) -> Tuple[Dict[int, GeoJsonDetailedWeatherStation], Dict[str, int]]: - """Fetch and marshall geojson station data""" - stations = {} - id_to_code_map = {} - # 1 week seems a reasonable period to cache stations for. - redis_station_cache_expiry: Final = int(config.get("REDIS_STATION_CACHE_EXPIRY", 604800)) - # Put the stations in a nice dictionary. 
- async for raw_station in fetch_paged_response_generator(session, headers, query_builder, "stations", True, redis_station_cache_expiry): - station_code = raw_station.get("stationCode") - station_status = raw_station.get("stationStatus", {}).get("id") - # Because we can't filter on status in the RSQL, we have to manually exclude stations that are - # not active. - if is_station_valid(raw_station): - id_to_code_map[raw_station.get("id")] = station_code - geojson_station = GeoJsonDetailedWeatherStation( - properties=DetailedWeatherStationProperties(code=station_code, name=raw_station.get("displayLabel")), - geometry=WeatherStationGeometry(coordinates=[raw_station.get("longitude"), raw_station.get("latitude")]), - ) - stations[station_code] = geojson_station - else: - logger.debug("station %s, status %s", station_code, station_status) - - return stations, id_to_code_map - - -async def fetch_raw_dailies_for_all_stations(session: ClientSession, headers: dict, time_of_interest: datetime) -> list: - """Fetch the noon values(observations and forecasts) for a given time, for all weather stations.""" - # We don't know how many pages until our first call - so we assume one page to start with. - total_pages = 1 - page_count = 0 - hourlies = [] - while page_count < total_pages: - # Build up the request URL. 
- url, params = prepare_fetch_dailies_for_all_stations_query(time_of_interest, page_count) - # Get dailies - async with session.get(url, params=params, headers=headers) as response: - dailies_json = await response.json() - total_pages = dailies_json["page"]["totalPages"] - hourlies.extend(dailies_json["_embedded"]["dailies"]) - page_count = page_count + 1 - return hourlies - - -def prepare_fetch_hourlies_query(raw_station: dict, start_timestamp: datetime, end_timestamp: datetime): - """Prepare url and params to fetch hourly readings from the WFWX Fireweather API.""" - base_url = config.get("WFWX_BASE_URL") - - logger.debug("requesting historic data from %s to %s", start_timestamp, end_timestamp) - - # Prepare query params and query: - query_start_timestamp = math.floor(start_timestamp.timestamp() * 1000) - query_end_timestamp = math.floor(end_timestamp.timestamp() * 1000) - - station_id = raw_station["id"] - params = {"startTimestamp": query_start_timestamp, "endTimestamp": query_end_timestamp, "stationId": station_id} - endpoint = "/v1/hourlies/search/findHourliesByWeatherTimestampBetweenAndStationIdEqualsOrderByWeatherTimestampAsc" - url = f"{base_url}{endpoint}" - - return url, params - - -def prepare_fetch_dailies_for_all_stations_query(time_of_interest: datetime, page_count: int): - """Prepare url and params for fetching dailies(that's forecast and observations for noon) for all. - stations.""" - base_url = config.get("WFWX_BASE_URL") - noon_date = _get_noon_date(time_of_interest) - timestamp = int(noon_date.timestamp() * 1000) - # one could filter on recordType.id==FORECAST or recordType.id==ACTUAL but we want it all. 
- params = {"query": f"weatherTimestamp=={timestamp}", "page": page_count, "size": config.get("WFWX_MAX_PAGE_SIZE", 1000)} - endpoint = "/v1/dailies/rsql" - url = f"{base_url}{endpoint}" - logger.info("%s %s", url, params) - return url, params - - -async def fetch_hourlies( - session: ClientSession, raw_station: dict, headers: dict, start_timestamp: datetime, end_timestamp: datetime, use_cache: bool, eco_division: EcodivisionSeasons -) -> WeatherStationHourlyReadings: - """Fetch hourly weather readings for the specified time range for a give station""" - logger.debug("fetching hourlies for %s(%s)", raw_station["displayLabel"], raw_station["stationCode"]) - - url, params = prepare_fetch_hourlies_query(raw_station, start_timestamp, end_timestamp) - - cache_expiry_seconds: Final = int(config.get("REDIS_HOURLIES_BY_STATION_CODE_CACHE_EXPIRY", 300)) - - # Get hourlies - if use_cache and config.get("REDIS_USE") == "True": - hourlies_json = await _fetch_cached_response(session, headers, url, params, cache_expiry_seconds) - else: - async with session.get(url, params=params, headers=headers) as response: - hourlies_json = await response.json() - - hourlies = [] - for hourly in hourlies_json["_embedded"]["hourlies"]: - # We only accept "ACTUAL" values - if hourly.get("hourlyMeasurementTypeCode", "").get("id") == "ACTUAL": - hourlies.append(parse_hourly(hourly)) - - logger.debug("fetched %d hourlies for %s(%s)", len(hourlies), raw_station["displayLabel"], raw_station["stationCode"]) - - return WeatherStationHourlyReadings(values=hourlies, station=parse_station(raw_station, eco_division)) - - -async def fetch_access_token(session: ClientSession) -> dict: - """Fetch an access token for WFWX Fireweather API""" - logger.debug("fetching access token...") - password = config.get("WFWX_SECRET") - user = config.get("WFWX_USER") - auth_url = config.get("WFWX_AUTH_URL") - cache = create_redis() - # NOTE: Consider using a hashed version of the password as part of the key. 
- params = {"user": user} - key = f"{auth_url}?{urlencode(params)}" - try: - cached_json = cache.get(key) - except Exception as error: - cached_json = None - logger.error(error, exc_info=error) - if cached_json: - logger.info("redis cache hit %s", auth_url) - response_json = json.loads(cached_json.decode()) - else: - logger.info("redis cache miss %s", auth_url) - async with session.get(auth_url, auth=BasicAuth(login=user, password=password)) as response: - response_json = await response.json() - try: - if response.status == 200: - # We expire when the token expires, or 10 minutes, whichever is less. - # NOTE: only caching for 10 minutes right now, since we aren't handling cases - # where the token is invalidated. - redis_auth_cache_expiry: Final = int(config.get("REDIS_AUTH_CACHE_EXPIRY", 600)) - expires = min(response_json["expires_in"], redis_auth_cache_expiry) - cache.set(key, json.dumps(response_json).encode(), ex=expires) - except Exception as error: - logger.error(error, exc_info=error) - return response_json - - -async def fetch_stations_by_group_id(session: ClientSession, headers: dict, group_id: str): - logger.debug(f"Fetching stations for group {group_id}") - base_url = config.get("WFWX_BASE_URL") - url = f"{base_url}/v1/stationGroups/{group_id}/members" - - async with session.get(url, headers=headers) as response: - raw_stations = await response.json() - return raw_stations diff --git a/backend/packages/wps-wf1/README.md b/backend/packages/wps-wf1/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/packages/wps-wf1/pyproject.toml b/backend/packages/wps-wf1/pyproject.toml new file mode 100644 index 0000000000..049e41aef6 --- /dev/null +++ b/backend/packages/wps-wf1/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "wps-wf1" +version = "0.1.0" +description = "Wildfire Predictive Services Unit WF1 utils" +authors = [ + { name = "Darren Boss", email = "darren.boss@gov.bc.ca" } +] +requires-python = ">=3.12.3,<4.0" +dependencies 
= [ + "aiohttp>=3.13.2", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/wps_wf1"] diff --git a/backend/packages/wps-wf1/src/wps_wf1/__init__.py b/backend/packages/wps-wf1/src/wps_wf1/__init__.py new file mode 100644 index 0000000000..92d06689e1 --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/__init__.py @@ -0,0 +1,12 @@ + +"""Thin WFWX API client with optional caching. +""" +from wps_wf1.wfwx_settings import WfwxSettings +from wps_wf1.wfwx_client import WfwxClient +from wps_wf1.cache_protocol import CacheProtocol + +__all__ = [ + 'WfwxSettings', + 'WfwxClient', + 'CacheProtocol', +] diff --git a/backend/packages/wps-wf1/src/wps_wf1/cache_protocol.py b/backend/packages/wps-wf1/src/wps_wf1/cache_protocol.py new file mode 100644 index 0000000000..70b17f8cad --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/cache_protocol.py @@ -0,0 +1,11 @@ +from typing import Optional, Protocol + + +class CacheProtocol(Protocol): + """ + Interface for cache implementation in wps-wf1 package to support dependency injection + """ + + def get(self, key: str) -> Optional[bytes]: ... + + def set(self, key: str, value: bytes, ex: int) -> None: ... 
diff --git a/backend/packages/wps-wf1/src/wps_wf1/query_builders.py b/backend/packages/wps-wf1/src/wps_wf1/query_builders.py new file mode 100644 index 0000000000..8a0b8b3209 --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/query_builders.py @@ -0,0 +1,128 @@ +""" Query builder classes for making requests to WFWX API """ +from typing import List, Tuple +from abc import abstractmethod, ABC + +from wps_shared import config + + +class BuildQuery(ABC): + """ Base class for building query urls and params """ + + def __init__(self): + """ Initialize object """ + self.max_page_size = config.get('WFWX_MAX_PAGE_SIZE', 1000) + self.base_url = config.get('WFWX_BASE_URL') + + @abstractmethod + def query(self, page) -> Tuple[str, dict]: + """ Return query url and params """ + + +class BuildQueryStations(BuildQuery): + """ Class for building a url and RSQL params to request all active stations. """ + + def __init__(self): + """ Prepare filtering on active, test and project stations. """ + super().__init__() + self.param_query = None + # In conversation with Dana Hicks, on Apr 20, 2021 - Dana said to show active, test and project. + for status in ('ACTIVE', 'TEST', 'PROJECT'): + if self.param_query: + self.param_query += f',stationStatus.id=="{status}"' + else: + self.param_query = f'stationStatus.id=="{status}"' + + def query(self, page) -> Tuple[str, dict]: + """ Return query url and params with rsql query for all weather stations marked active. 
""" + params = {'size': self.max_page_size, 'sort': 'displayLabel', + 'page': page, 'query': self.param_query} + url = f'{self.base_url}/v1/stations' + return url, params + + +class BuildQueryByStationCode(BuildQuery): + """ Class for building a url and params to request a list of stations by code """ + + def __init__(self, station_codes: List[int]): + """ Initialize object """ + super().__init__() + self.querystring = '' + for code in station_codes: + if len(self.querystring) > 0: + self.querystring += ' or ' + self.querystring += f'stationCode=={code}' + + def query(self, page) -> Tuple[str, dict]: + """ Return query url and params for a list of stations """ + params = {'size': self.max_page_size, + 'sort': 'displayLabel', 'page': page, 'query': self.querystring} + url = f'{self.base_url}/v1/stations/rsql' + return url, params + + +class BuildQueryAllHourliesByRange(BuildQuery): + """ Builds query for requesting all hourlies in a time range""" + + def __init__(self, start_timestamp: int, end_timestamp: int): + """ Initialize object """ + super().__init__() + self.querystring: str = "weatherTimestamp >=" + \ + str(start_timestamp) + ";" + "weatherTimestamp <" + str(end_timestamp) + + def query(self, page) -> Tuple[str, dict]: + """ Return query url for hourlies between start_timestamp, end_timestamp""" + params = {'size': self.max_page_size, 'page': page, 'query': self.querystring} + url = f'{self.base_url}/v1/hourlies/rsql' + return url, params + + +class BuildQueryAllForecastsByAfterStart(BuildQuery): + """ Builds query for requesting all dailies in a time range""" + + def __init__(self, start_timestamp: int): + """ Initialize object """ + super().__init__() + self.querystring = f"weatherTimestamp >={start_timestamp};recordType.id == 'FORECAST'" + + def query(self, page) -> Tuple[str, dict]: + """ Return query url for dailies between start_timestamp, end_timestamp""" + params = {'size': self.max_page_size, 'page': page, 'query': self.querystring} + url = 
f'{self.base_url}/v1/dailies/rsql' + return url, params + + +class BuildQueryDailiesByStationCode(BuildQuery): + """ Builds query for requesting dailies in a time range for the station codes""" + + def __init__(self, start_timestamp: int, end_timestamp: int, station_ids: List[str]): + """ Initialize object """ + super().__init__() + self.start_timestamp = start_timestamp + self.end_timestamp = end_timestamp + self.station_ids = station_ids + + def query(self, page) -> Tuple[str, dict]: + """ Return query url for dailies between start_timestamp, end_timestamp""" + params = {'size': self.max_page_size, + 'page': page, + 'startingTimestamp': self.start_timestamp, + 'endingTimestamp': self.end_timestamp, + 'stationIds': self.station_ids} + url = (f'{self.base_url}/v1/dailies/search/findDailiesByStationIdIsInAndWeather' + + 'TimestampBetweenOrderByStationIdAscWeatherTimestampAsc') + return url, params + + +class BuildQueryStationGroups(BuildQuery): + """ Builds a query for requesting all station groups """ + + def __init__(self): + """ Initialize object. """ + super().__init__() + self.param_query = None + + def query(self, page) -> Tuple[str, dict]: + """ Return query url and params with query for all weather stations groups.
""" + params = {'size': self.max_page_size, 'page': page, 'sort': 'groupOwnerUserId,asc'} + url = f'{self.base_url}/v1/stationGroups' + return url, params diff --git a/backend/packages/wps-wf1/src/wps_wf1/tests/conftest.py b/backend/packages/wps-wf1/src/wps_wf1/tests/conftest.py new file mode 100644 index 0000000000..aa7f6d5473 --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/tests/conftest.py @@ -0,0 +1,43 @@ +"""Global fixtures""" + +from unittest.mock import MagicMock + +import pytest +from aiohttp import ClientSession +from wps_wf1.cache_protocol import CacheProtocol +from wps_wf1.wfwx_settings import WfwxSettings + + +@pytest.fixture(autouse=True) +def mock_env(monkeypatch): + """Automatically mock environment variable""" + monkeypatch.setenv("BASE_URI", "https://python-test-base-uri") + monkeypatch.setenv("WFWX_USER", "user") + monkeypatch.setenv("WFWX_SECRET", "secret") + monkeypatch.setenv("WFWX_AUTH_URL", "https://wf1/pub/oauth2/v1/oauth/token") + monkeypatch.setenv("WFWX_BASE_URL", "https://wf1/wfwx") + monkeypatch.setenv("WFWX_MAX_PAGE_SIZE", "1000") + + +@pytest.fixture +def mock_session(): + """Mock ClientSession for unit tests""" + return MagicMock(spec=ClientSession) + + +@pytest.fixture +def mock_settings(): + """Mock WfwxSettings for unit tests""" + return WfwxSettings( + base_url="https://test.example.com", + auth_url="https://auth.example.com", + user="test_user", + secret="test_secret", + max_page_size=100, + ) + + +@pytest.fixture +def mock_cache(): + """Mock CacheProtocol for unit tests""" + return MagicMock(spec=CacheProtocol) diff --git a/backend/packages/wps-wf1/src/wps_wf1/tests/test_query_builders.py b/backend/packages/wps-wf1/src/wps_wf1/tests/test_query_builders.py new file mode 100644 index 0000000000..df0f87ff73 --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/tests/test_query_builders.py @@ -0,0 +1,54 @@ +from wps_wf1.query_builders import ( + BuildQueryAllForecastsByAfterStart, + BuildQueryAllHourliesByRange, + 
BuildQueryDailiesByStationCode, + BuildQueryStationGroups, +) + + +def test_build_all_hourlies_query(): + """Verifies the query builder returns the correct url and parameters""" + query_builder = BuildQueryAllHourliesByRange(0, 1) + result = query_builder.query(0) + assert result == ( + "https://wf1/wfwx/v1/hourlies/rsql", + {"size": "1000", "page": 0, "query": "weatherTimestamp >=0;weatherTimestamp <1"}, + ) + + +def test_build_forecasts_query(): + """Verifies the query builder returns the correct url and parameters""" + query_builder = BuildQueryAllForecastsByAfterStart(0) + result = query_builder.query(0) + assert result == ( + "https://wf1/wfwx/v1/dailies/rsql", + {"size": "1000", "page": 0, "query": "weatherTimestamp >=0;recordType.id == 'FORECAST'"}, + ) + + +def test_build_dailies_by_station_code(): + """Verifies the query builder returns the correct url and parameters for dailies by station code""" + query_builder = BuildQueryDailiesByStationCode(0, 1, ["1", "2"]) + result = query_builder.query(0) + assert result == ( + "https://wf1/wfwx/v1/dailies/search/" + + "findDailiesByStationIdIsInAndWeather" + + "TimestampBetweenOrderByStationIdAscWeatherTimestampAsc", + { + "size": "1000", + "page": 0, + "startingTimestamp": 0, + "endingTimestamp": 1, + "stationIds": ["1", "2"], + }, + ) + + +def test_build_station_groups_query(): + """Verifies the query builder returns the correct url and parameters for a station groups query""" + query_builder = BuildQueryStationGroups() + result = query_builder.query(0) + assert result == ( + "https://wf1/wfwx/v1/stationGroups", + {"size": "1000", "page": 0, "sort": "groupOwnerUserId,asc"}, + ) diff --git a/backend/packages/wps-wf1/src/wps_wf1/tests/test_wfwx_client.py b/backend/packages/wps-wf1/src/wps_wf1/tests/test_wfwx_client.py new file mode 100644 index 0000000000..bf302b8c64 --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/tests/test_wfwx_client.py @@ -0,0 +1,483 @@ +"""Unit tests for wfwx_client.py""" + +import 
json +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock +from urllib.parse import urlencode + +import pytest +from wps_wf1.query_builders import BuildQuery +from wps_wf1.wfwx_client import WfwxClient, _cache_key + + +class MockAsyncContextManager: + """Mock async context manager for aiohttp responses""" + + def __init__(self, response_data): + self.response_data = response_data + + async def __aenter__(self): + mock_response = AsyncMock() + mock_response.json.return_value = self.response_data + mock_response.raise_for_status.return_value = None + return mock_response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class TestCacheKey: + """Test cases for the _cache_key function""" + + def test_cache_key_generation(self): + """Test that cache key is generated correctly from URL and params""" + url = "https://example.com/api/data" + params = {"param1": "value1", "param2": "value2"} + + expected_key = f"{url}?{urlencode(params)}" + result = _cache_key(url, params) + + assert result == expected_key + + def test_cache_key_with_empty_params(self): + """Test cache key generation with empty parameters""" + url = "https://example.com/api/data" + params = {} + + expected_key = f"{url}?" 
+ result = _cache_key(url, params) + + assert result == expected_key + + def test_cache_key_with_special_characters(self): + """Test cache key generation with special characters in parameters""" + url = "https://example.com/api/data" + params = {"query": "name=test&value=123", "filter": "active=true"} + + expected_key = f"{url}?{urlencode(params)}" + result = _cache_key(url, params) + + assert result == expected_key + + +class TestWfwxClient: + """Test cases for the WfwxClient class""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings, mock_cache): + """Create a WfwxClient instance with mocked dependencies""" + return WfwxClient(mock_session, mock_settings, mock_cache) + + def test_init(self, mock_session, mock_settings, mock_cache): + """Test WfwxClient initialization""" + client = WfwxClient(mock_session, mock_settings, mock_cache) + + assert client.session == mock_session + assert client.settings == mock_settings + assert client.cache == mock_cache + + def test_init_without_cache(self, mock_session, mock_settings): + """Test WfwxClient initialization without cache""" + client = WfwxClient(mock_session, mock_settings) + + assert client.session == mock_session + assert client.settings == mock_settings + assert client.cache is None + + +class TestWfwxClientGetJson: + """Test cases for the _get_json method""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings, mock_cache): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings, mock_cache) + + @pytest.mark.anyio + async def test_get_json_with_cache_hit(self, wfwx_client, mock_cache): + """Test _get_json returns cached data when available""" + url = "https://test.example.com/api/data" + headers = {"Authorization": "Bearer token"} + params = {"key": "value"} + cached_data = {"cached": True} + + # Setup cache to return data + mock_cache.get.return_value = json.dumps(cached_data).encode("utf-8") + + result = await wfwx_client._get_json(url, 
headers, params) + + assert result == cached_data + # Verify cache was checked but no HTTP request was made + mock_cache.get.assert_called_once() + wfwx_client.session.get.assert_not_called() + + @pytest.mark.anyio + async def test_get_json_with_cache_miss(self, wfwx_client, mock_cache): + """Test _get_json fetches data when not cached""" + url = "https://test.example.com/api/data" + headers = {"Authorization": "Bearer token"} + params = {"key": "value"} + response_data = {"data": "test"} + + # Setup cache to return None (no cached data) + mock_cache.get.return_value = None + + # Setup the session.get to return our mock context manager + wfwx_client.session.get.return_value = MockAsyncContextManager(response_data) + + result = await wfwx_client._get_json(url, headers, params) + + assert result == response_data + mock_cache.get.assert_called_once() + # Verify the data was cached + mock_cache.set.assert_called_once() + + @pytest.mark.anyio + async def test_get_json_without_cache(self, mock_session, mock_settings): + """Test _get_json when no cache is provided""" + client = WfwxClient(mock_session, mock_settings) + url = "https://test.example.com/api/data" + headers = {"Authorization": "Bearer token"} + params = {"key": "value"} + response_data = {"data": "test"} + + # Setup the session.get to return our mock context manager + mock_session.get.return_value = MockAsyncContextManager(response_data) + + result = await client._get_json(url, headers, params) + + assert result == response_data + # Verify HTTP request was made + mock_session.get.assert_called_once_with(url, headers=headers, params=params) + + @pytest.mark.anyio + async def test_get_json_with_use_cache_false(self, wfwx_client, mock_cache): + """Test _get_json respects use_cache=False parameter""" + url = "https://test.example.com/api/data" + headers = {"Authorization": "Bearer token"} + params = {"key": "value"} + response_data = {"data": "test"} + + # Setup the session.get to return our mock context manager 
+ wfwx_client.session.get.return_value = MockAsyncContextManager(response_data) + + result = await wfwx_client._get_json(url, headers, params, use_cache=False) + + assert result == response_data + # Verify cache was not checked + mock_cache.get.assert_not_called() + + @pytest.mark.anyio + async def test_get_json_with_custom_ttl(self, wfwx_client, mock_cache): + """Test _get_json uses custom TTL when provided""" + url = "https://test.example.com/api/data" + headers = {"Authorization": "Bearer token"} + params = {"key": "value"} + response_data = {"data": "test"} + custom_ttl = 3600 + + # Setup cache to return None (no cached data) + mock_cache.get.return_value = None + + # Setup the session.get to return our mock context manager + wfwx_client.session.get.return_value = MockAsyncContextManager(response_data) + + result = await wfwx_client._get_json(url, headers, params, ttl=custom_ttl) + + assert result == response_data + # Verify the data was cached with custom TTL + mock_cache.set.assert_called_once() + call_args = mock_cache.set.call_args + assert call_args[1]["ex"] == custom_ttl + + +class TestWfwxClientFetchAccessToken: + """Test cases for the fetch_access_token method""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings, mock_cache): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings, mock_cache) + + @pytest.mark.anyio + async def test_fetch_access_token_with_cache_hit(self, wfwx_client, mock_cache): + """Test fetch_access_token returns cached token when available""" + cached_token = {"access_token": "cached_token", "expires_in": 3600} + + # Setup cache to return data + mock_cache.get.return_value = json.dumps(cached_token).encode("utf-8") + + result = await wfwx_client.fetch_access_token(3600) + + assert result == cached_token + mock_cache.get.assert_called_once() + wfwx_client.session.get.assert_not_called() + + @pytest.mark.anyio + async def test_fetch_access_token_with_cache_miss(self, wfwx_client, 
mock_cache): + """Test fetch_access_token fetches new token when not cached""" + token_response = {"access_token": "new_token", "expires_in": 7200} + + # Setup cache to return None (no cached data) + mock_cache.get.return_value = None + + # Setup the session.get to return our mock context manager + wfwx_client.session.get.return_value = MockAsyncContextManager(token_response) + + result = await wfwx_client.fetch_access_token(3600) + + assert result == token_response + # Verify the token was cached with min(expires_in, ttl) + mock_cache.set.assert_called_once() + call_args = mock_cache.set.call_args + assert call_args[1]["ex"] == 3600 # min(7200, 3600) + + @pytest.mark.anyio + async def test_fetch_access_token_without_cache(self, mock_session, mock_settings): + """Test fetch_access_token when no cache is provided""" + client = WfwxClient(mock_session, mock_settings) + token_response = {"access_token": "new_token", "expires_in": 3600} + + # Setup the session.get to return our mock context manager + mock_session.get.return_value = MockAsyncContextManager(token_response) + + result = await client.fetch_access_token(3600) + + assert result == token_response + # Verify HTTP request was made with correct auth + mock_session.get.assert_called_once() + call_args = mock_session.get.call_args + assert call_args[1]["auth"].login == "test_user" + assert call_args[1]["auth"].password == "test_secret" + + +class TestWfwxClientFetchPagedResponse: + """Test cases for the fetch_paged_response_generator method""" + + @pytest.fixture + def mock_query_builder(self): + """Mock BuildQuery""" + mock = MagicMock(spec=BuildQuery) + return mock + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings, mock_cache): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings, mock_cache) + + @pytest.mark.anyio + async def test_fetch_paged_response_generator_single_page( + self, wfwx_client, mock_query_builder + ): + """Test fetch_paged_response_generator 
with single page of results""" + headers = {"Authorization": "Bearer token"} + content_key = "items" + + # Setup query builder to return URL and params + mock_query_builder.query.return_value = ("https://test.example.com/api/items", {"page": 0}) + + # Setup response data + response_data = {"page": {"totalPages": 1}, "_embedded": {"items": [{"id": 1}, {"id": 2}]}} + + # Mock _get_json to return response data + wfwx_client._get_json = AsyncMock(return_value=response_data) + + # Test the generator + results = [] + async for item in wfwx_client.fetch_paged_response_generator( + headers, mock_query_builder, content_key + ): + results.append(item) + + assert len(results) == 2 + assert results[0] == {"id": 1} + assert results[1] == {"id": 2} + + @pytest.mark.anyio + async def test_fetch_paged_response_generator_multiple_pages( + self, wfwx_client, mock_query_builder + ): + """Test fetch_paged_response_generator with multiple pages""" + headers = {"Authorization": "Bearer token"} + content_key = "items" + + # Setup query builder to return different URLs for each page + mock_query_builder.query.side_effect = [ + ("https://test.example.com/api/items", {"page": 0}), + ("https://test.example.com/api/items", {"page": 1}), + ("https://test.example.com/api/items", {"page": 2}), + ] + + # Setup response data for each page + response_data_page_0 = { + "page": {"totalPages": 3}, + "_embedded": {"items": [{"id": 1}, {"id": 2}]}, + } + response_data_page_1 = { + "page": {"totalPages": 3}, + "_embedded": {"items": [{"id": 3}, {"id": 4}]}, + } + response_data_page_2 = {"page": {"totalPages": 3}, "_embedded": {"items": [{"id": 5}]}} + + # Setup the session.get to return different responses for each call + wfwx_client.session.get.side_effect = [ + MockAsyncContextManager(response_data_page_0), + MockAsyncContextManager(response_data_page_1), + MockAsyncContextManager(response_data_page_2), + ] + + # Test the generator + results = [] + async for item in 
wfwx_client.fetch_paged_response_generator( + headers, mock_query_builder, content_key + ): + results.append(item) + + assert len(results) == 5 + assert results == [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}] + + +class TestWfwxClientFetchRawDailies: + """Test cases for the fetch_raw_dailies_for_all_stations method""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings) + + @pytest.mark.anyio + async def test_fetch_raw_dailies_single_page(self, wfwx_client): + """Test fetch_raw_dailies_for_all_stations with single page""" + headers = {"Authorization": "Bearer token"} + time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + timestamp = int(time_of_interest.timestamp() * 1000) + + # Setup response data + response_data = { + "page": {"totalPages": 1}, + "_embedded": {"dailies": [{"id": 1, "temp": 20}, {"id": 2, "temp": 22}]}, + } + + # Setup the session.get to return our mock context manager + wfwx_client.session.get.return_value = MockAsyncContextManager(response_data) + + result = await wfwx_client.fetch_raw_dailies_for_all_stations(headers, time_of_interest) + + assert len(result) == 2 + assert result[0] == {"id": 1, "temp": 20} + assert result[1] == {"id": 2, "temp": 22} + + # Verify the correct URL and parameters were used + expected_url = f"{wfwx_client.settings.base_url}/v1/dailies/rsql" + expected_params = { + "query": f"weatherTimestamp=={timestamp}", + "page": 0, + "size": wfwx_client.settings.max_page_size, + } + + wfwx_client.session.get.assert_called_once_with( + expected_url, params=expected_params, headers=headers + ) + + @pytest.mark.anyio + async def test_fetch_raw_dailies_multiple_pages(self, wfwx_client): + """Test fetch_raw_dailies_for_all_stations with multiple pages""" + headers = {"Authorization": "Bearer token"} + time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + + # Setup response data for multiple pages + response_data_page_0 = 
{"page": {"totalPages": 2}, "_embedded": {"dailies": [{"id": 1}]}} + response_data_page_1 = { + "page": {"totalPages": 2}, + "_embedded": {"dailies": [{"id": 2}, {"id": 3}]}, + } + + # Setup the session.get to return different responses for each call + wfwx_client.session.get.side_effect = [ + MockAsyncContextManager(response_data_page_0), + MockAsyncContextManager(response_data_page_1), + ] + + result = await wfwx_client.fetch_raw_dailies_for_all_stations(headers, time_of_interest) + + assert len(result) == 3 + assert result == [{"id": 1}, {"id": 2}, {"id": 3}] + + +class TestWfwxClientFetchHourlies: + """Test cases for hourlies-related methods""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings, mock_cache): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings, mock_cache) + + def test_prepare_fetch_hourlies_query(self, wfwx_client): + """Test prepare_fetch_hourlies_query generates correct URL and parameters""" + raw_station = {"id": "station123"} + start_datetime = datetime(2023, 1, 1, 0, 0, 0) + end_datetime = datetime(2023, 1, 1, 23, 59, 59) + + start_ts = int(start_datetime.timestamp() * 1000) + end_ts = int(end_datetime.timestamp() * 1000) + + url, params = wfwx_client.prepare_fetch_hourlies_query( + raw_station, start_datetime, end_datetime + ) + + expected_url = f"{wfwx_client.settings.base_url}/v1/hourlies/search/findHourliesByWeatherTimestampBetweenAndStationIdEqualsOrderByWeatherTimestampAsc" + expected_params = { + "startTimestamp": start_ts, + "endTimestamp": end_ts, + "stationId": "station123", + } + + assert url == expected_url + assert params == expected_params + + @pytest.mark.anyio + async def test_fetch_hourlies(self, wfwx_client): + """Test fetch_hourlies calls _get_json with correct parameters""" + raw_station = {"id": "station123"} + headers = {"Authorization": "Bearer token"} + start_datetime = datetime(2023, 1, 1, 0, 0, 0) + end_datetime = datetime(2023, 1, 1, 23, 59, 59) + use_cache = 
True + ttl = 3600 + + response_data = {"hourlies": [{"temp": 20}, {"temp": 22}]} + wfwx_client._get_json = AsyncMock(return_value=response_data) + + result = await wfwx_client.fetch_hourlies( + raw_station, headers, start_datetime, end_datetime, use_cache, ttl + ) + + assert result == response_data + wfwx_client._get_json.assert_called_once() + + +class TestWfwxClientFetchStations: + """Test cases for the fetch_stations_by_group_id method""" + + @pytest.fixture + def wfwx_client(self, mock_session, mock_settings): + """Create a WfwxClient instance""" + return WfwxClient(mock_session, mock_settings) + + @pytest.mark.anyio + async def test_fetch_stations_by_group_id(self, wfwx_client): + """Test fetch_stations_by_group_id fetches stations for a group""" + headers = {"Authorization": "Bearer token"} + group_id = "group123" + response_data = {"_embedded": {"stations": [{"id": 1}, {"id": 2}]}} + + # Setup the session.get to return our mock context manager + wfwx_client.session.get.return_value = MockAsyncContextManager(response_data) + + result = await wfwx_client.fetch_stations_by_group_id(headers, group_id) + + assert result == response_data + + # Verify the correct URL was used + expected_url = f"{wfwx_client.settings.base_url}/v1/stationGroups/{group_id}/members" + wfwx_client.session.get.assert_called_once_with(expected_url, headers=headers) diff --git a/backend/packages/wps-wf1/src/wps_wf1/wfwx_client.py b/backend/packages/wps-wf1/src/wps_wf1/wfwx_client.py new file mode 100644 index 0000000000..05fcaffcfb --- /dev/null +++ b/backend/packages/wps-wf1/src/wps_wf1/wfwx_client.py @@ -0,0 +1,155 @@ +import json +import logging +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, Optional +from urllib.parse import urlencode + +from aiohttp import BasicAuth, ClientSession + +from wps_wf1.cache_protocol import CacheProtocol +from wps_wf1.query_builders import BuildQuery +from wps_wf1.wfwx_settings import WfwxSettings + +logger = 
logger = logging.getLogger(__name__)

# Default lifetime for cached WF1 responses: one day, in seconds.
DEFAULT_TTL = 86400


def _cache_key(url: str, params: Dict[str, Any]) -> str:
    """
    Generate a key to use for caching from the provided url and parameter dictionary

    :param url: The URL
    :param params: The key-value pairs to include in the cache key.
    :return: A string representing the derived cache key.
    """
    return f"{url}?{urlencode(params)}"


class WfwxClient:
    """Async client for the WFWX (Wildfire One) API with optional response caching."""

    def __init__(
        self, session: ClientSession, settings: WfwxSettings, cache: Optional[CacheProtocol] = None
    ):
        """
        :param session: aiohttp session used for all HTTP requests.
        :param settings: WF1 connection settings (base/auth URLs, credentials, page size).
        :param cache: Optional response cache; None disables caching entirely.
        """
        self.session = session
        self.settings = settings
        self.cache = cache

    async def _get_json(
        self,
        url: str,
        headers: Dict[str, Any],
        params: Dict[str, Any],
        use_cache: bool = True,
        ttl: int = DEFAULT_TTL,
    ) -> Dict[str, Any]:
        """
        GET ``url`` and return the decoded JSON body.

        When ``use_cache`` is set and a cache is configured, the response is
        served from / stored into the cache keyed on url + params.

        :param url: Request URL.
        :param headers: HTTP headers (typically the auth header).
        :param params: Query string parameters.
        :param use_cache: Whether the cache may be consulted/populated for this call.
        :param ttl: Cache entry lifetime in seconds.
        :raises aiohttp.ClientResponseError: If the response status indicates an error.
        """
        key = _cache_key(url, params)
        if use_cache and self.cache:
            cached = self.cache.get(key)
            if cached:
                return json.loads(cached.decode("utf-8"))

        async with self.session.get(url, headers=headers, params=params) as resp:
            resp.raise_for_status()
            data = await resp.json()

        if use_cache and self.cache:
            self.cache.set(key, json.dumps(data).encode("utf-8"), ex=ttl)

        return data

    async def fetch_access_token(self, ttl: int) -> Dict[str, Any]:
        """
        Fetch an access token from the WF1 auth endpoint using basic auth.

        The token response is cached (when a cache is configured) for the lesser
        of ``ttl`` and the token's own ``expires_in``, so a cached token is never
        served beyond its validity window.

        :param ttl: Maximum number of seconds to cache the token response.
        :return: The decoded token response (contains ``access_token`` etc.).
        """
        url = self.settings.auth_url
        params = {"user": self.settings.user}
        key = _cache_key(url, params)

        if self.cache:
            cached = self.cache.get(key)
            if cached:
                return json.loads(cached.decode("utf-8"))

        async with self.session.get(
            url, auth=BasicAuth(self.settings.user, self.settings.secret)
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

        expires = min(data.get("expires_in", ttl), ttl)
        # Only cache while the token is still valid; cache backends such as
        # redis reject non-positive expiry values.
        if self.cache and expires > 0:
            self.cache.set(key, json.dumps(data).encode("utf-8"), ex=expires)

        return data

    async def fetch_paged_response_generator(
        self,
        headers: Dict[str, Any],
        query_builder: BuildQuery,
        content_key: str,
        use_cache: bool = False,
        ttl: int = DEFAULT_TTL,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Iterate over every item of a paged WF1 collection, fetching pages lazily.

        :param headers: HTTP headers (typically the auth header).
        :param query_builder: Builds the url/params for each page number.
        :param content_key: Key under ``_embedded`` holding the page's items.
        :param use_cache: Whether page responses may be cached.
        :param ttl: Cache entry lifetime in seconds.
        """
        total_pages = 1
        page_count = 0
        while page_count < total_pages:
            # Build up the request URL.
            url, params = query_builder.query(page_count)
            logger.debug("loading page %d...", page_count)
            data = await self._get_json(url, headers, params, use_cache, ttl)
            total_pages = data.get("page", {}).get("totalPages", 1)
            # HAL-style responses typically omit "_embedded" when a page has no
            # results, so fall back to an empty list rather than raising KeyError.
            for obj in data.get("_embedded", {}).get(content_key, []):
                yield obj
            page_count += 1

    async def fetch_raw_dailies_for_all_stations(
        self, headers: Dict[str, Any], time_of_interest: datetime
    ) -> list:
        """
        Fetch the raw daily records for all stations at a single weather timestamp.

        :param headers: HTTP headers (typically the auth header).
        :param time_of_interest: The weather timestamp to match exactly.
        :return: All dailies collected across every page of the response.
        """
        timestamp = int(time_of_interest.timestamp() * 1000)
        params = {
            "query": f"weatherTimestamp=={timestamp}",
            "page": 0,
            "size": self.settings.max_page_size,
        }
        url = f"{self.settings.base_url}/v1/dailies/rsql"

        total_pages = 1
        page_count = 0
        results = []
        while page_count < total_pages:
            p = {**params, "page": page_count}
            async with self.session.get(url, params=p, headers=headers) as resp:
                resp.raise_for_status()
                data = await resp.json()
            # Guard against responses without "page"/"_embedded" (empty result set).
            total_pages = data.get("page", {}).get("totalPages", 1)
            results.extend(data.get("_embedded", {}).get("dailies", []))
            page_count += 1
        return results

    def prepare_fetch_hourlies_query(
        self, raw_station: dict, start_datetime: datetime, end_datetime: datetime
    ):
        """
        Build the url/params for the hourlies-by-station-and-range endpoint.

        :param raw_station: Raw WF1 station record; only its "id" is used.
        :param start_datetime: Inclusive start of the range.
        :param end_datetime: End of the range.
        :return: Tuple of (url, query params) with timestamps in epoch milliseconds.
        """
        start_ts = int(start_datetime.timestamp() * 1000)
        end_ts = int(end_datetime.timestamp() * 1000)
        params = {
            "startTimestamp": start_ts,
            "endTimestamp": end_ts,
            "stationId": raw_station["id"],
        }
        url = f"{self.settings.base_url}/v1/hourlies/search/findHourliesByWeatherTimestampBetweenAndStationIdEqualsOrderByWeatherTimestampAsc"
        return url, params

    async def fetch_hourlies(
        self,
        raw_station: dict,
        headers: Dict[str, Any],
        start_datetime: datetime,
        end_datetime: datetime,
        use_cache: bool,
        ttl: int,
    ) -> dict:
        """
        Fetch the hourly records for one station over a time range.

        :param raw_station: Raw WF1 station record; only its "id" is used.
        :param headers: HTTP headers (typically the auth header).
        :param start_datetime: Inclusive start of the range.
        :param end_datetime: End of the range.
        :param use_cache: Whether the response may be cached.
        :param ttl: Cache entry lifetime in seconds.
        """
        url, params = self.prepare_fetch_hourlies_query(raw_station, start_datetime, end_datetime)
        return await self._get_json(url, headers, params, use_cache, ttl)

    async def fetch_stations_by_group_id(self, headers: Dict[str, Any], group_id: str) -> dict:
        """
        Fetch the member stations of a station group (never cached).

        :param headers: HTTP headers (typically the auth header).
        :param group_id: WF1 station group identifier.
        :raises aiohttp.ClientResponseError: If the response status indicates an error.
        """
        url = f"{self.settings.base_url}/v1/stationGroups/{group_id}/members"
        async with self.session.get(url, headers=headers) as resp:
            resp.raise_for_status()
            return await resp.json()
"0.1.0" +source = { editable = "packages/wps-wf1" } +dependencies = [ + { name = "aiohttp" }, +] + +[package.metadata] +requires-dist = [{ name = "aiohttp", specifier = ">=3.13.2" }] + [[package]] name = "wrapt" version = "1.17.3" From cf67166ff95dba3aab4da9b718a5763134e40f0e Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 14:13:40 -0800 Subject: [PATCH 02/10] Fix hourly tests --- .../wps-api/src/app/tests/jobs/test_hourly_actuals.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/packages/wps-api/src/app/tests/jobs/test_hourly_actuals.py b/backend/packages/wps-api/src/app/tests/jobs/test_hourly_actuals.py index 35eaa87843..3d13176b14 100644 --- a/backend/packages/wps-api/src/app/tests/jobs/test_hourly_actuals.py +++ b/backend/packages/wps-api/src/app/tests/jobs/test_hourly_actuals.py @@ -2,8 +2,10 @@ import math import os import logging +from unittest.mock import MagicMock import pytest from pytest_mock import MockerFixture +from wps_wf1.wfwx_client import WfwxClient from wps_shared.db.models.observations import HourlyActual from app.tests.jobs.job_fixtures import mock_wfwx_stations, mock_wfwx_response from wps_shared.utils.time import get_utc_now @@ -19,9 +21,13 @@ def mock_hourly_actuals(mocker: MockerFixture): """ Mocks out hourly actuals as async result """ wfwx_hourlies = mock_wfwx_response() future_wfwx_stations = mock_wfwx_stations() + mock_wfwx_client = MagicMock() + mock_wfwx_client.fetch_paged_response_generator = iter(wfwx_hourlies) mocker.patch("wps_shared.wildfire_one.wfwx_api.wfwx_station_list_mapper", return_value=future_wfwx_stations) mocker.patch("wps_shared.wildfire_one.wfwx_api.get_hourly_actuals_all_stations", return_value=wfwx_hourlies) - mocker.patch("wps_shared.wildfire_one.wildfire_fetchers.fetch_paged_response_generator", return_value=iter(wfwx_hourlies)) + mocker.patch( + "wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client", return_value=mock_wfwx_client + ) def 
test_hourly_actuals_job(monkeypatch, mocker: MockerFixture, mock_hourly_actuals): From 763b0c0785d32b57dfb4899cd9c612b4ecf723a3 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 14:17:34 -0800 Subject: [PATCH 03/10] Fix noon forecasts --- .../wps-api/src/app/tests/jobs/test_noon_forecasts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/packages/wps-api/src/app/tests/jobs/test_noon_forecasts.py b/backend/packages/wps-api/src/app/tests/jobs/test_noon_forecasts.py index 1794239f3f..e42b0512d6 100644 --- a/backend/packages/wps-api/src/app/tests/jobs/test_noon_forecasts.py +++ b/backend/packages/wps-api/src/app/tests/jobs/test_noon_forecasts.py @@ -1,6 +1,7 @@ """ Unit tests for the fireweather noon forecats job """ import os import logging +from unittest.mock import MagicMock import pytest from pytest_mock import MockerFixture from app.jobs import noon_forecasts @@ -15,10 +16,14 @@ def mock_noon_forecasts(mocker: MockerFixture): """ Mocks out noon forecasts as async result """ wfwx_hourlies = mock_wfwx_response() future_wfwx_stations = mock_wfwx_stations() + mock_wfwx_client = MagicMock() + mock_wfwx_client.fetch_paged_response_generator = iter(wfwx_hourlies) mocker.patch("wps_shared.wildfire_one.wfwx_api.wfwx_station_list_mapper", return_value=future_wfwx_stations) mocker.patch("wps_shared.wildfire_one.wfwx_api.get_noon_forecasts_all_stations", return_value=wfwx_hourlies) - mocker.patch("wps_shared.wildfire_one.wildfire_fetchers.fetch_paged_response_generator", return_value=iter(wfwx_hourlies)) + mocker.patch( + "wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client", return_value=mock_wfwx_client + ) def test_noon_forecasts_bot(monkeypatch, mocker: MockerFixture, mock_noon_forecasts): From fd1a759a1be9d2cd988371df2fe191dcc9e3ecab Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 14:29:33 -0800 Subject: [PATCH 04/10] Update dockerfile --- Dockerfile | 14 ++++++++++---- 1 file changed, 10 
insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 06231e5426..ad08c5c4df 100644 --- a/Dockerfile +++ b/Dockerfile @@ -32,6 +32,8 @@ COPY ./backend/uv.lock /app/ COPY ./backend/packages/wps-api/pyproject.toml /app/packages/wps-api/ COPY ./backend/packages/wps-shared/pyproject.toml /app/packages/wps-shared/ COPY ./backend/packages/wps-shared/src /app/packages/wps-shared/src +COPY ./backend/packages/wps-wf1/pyproject.toml /app/packages/wps-wf1/ +COPY ./backend/packages/wps-wf1/src /app/packages/wps-wf1/src # Switch to root to set file permissions USER 0 @@ -39,8 +41,10 @@ USER 0 # Set configuration files to read-only for security RUN chmod 444 /app/pyproject.toml /app/uv.lock \ /app/packages/wps-api/pyproject.toml \ - /app/packages/wps-shared/pyproject.toml -RUN chmod -R a-w /app/packages/wps-shared/src + /app/packages/wps-shared/pyproject.toml \ + /app/packages/wps-wf1/pyproject.toml +RUN chmod -R a-w /app/packages/wps-shared/src \ + /app/packages/wps-wf1/src # Switch back to non-root user USER $USERNAME @@ -77,6 +81,7 @@ WORKDIR /app COPY --from=builder /app/pyproject.toml /app/ COPY --from=builder /app/packages/wps-api/pyproject.toml /app/packages/wps-api/ COPY --from=builder /app/packages/wps-shared/pyproject.toml /app/packages/wps-shared/ +COPY --from=builder /app/packages/wps-wf1/pyproject.toml /app/packages/wps-wf1/ # Switch back to our non-root user USER $USERNAME @@ -96,8 +101,9 @@ COPY ./backend/packages/wps-api/alembic.ini /app COPY ./backend/packages/wps-api/prestart.sh /app COPY ./backend/packages/wps-api/start.sh /app -# Make uv happy by copying wps_shared +# Make uv happy by copying wps_shared and wps_wf1 COPY ./backend/packages/wps-shared/src /app/packages/wps-shared/src +COPY ./backend/packages/wps-wf1/src /app/packages/wps-wf1/src # Copy installed Python packages COPY --from=builder /app/.venv /app/.venv @@ -115,7 +121,7 @@ ENV VIRTUAL_ENV="/app/.venv" # root user please USER 0 # Remove write permissions from copied 
configuration and source files for security -RUN chmod -R a-w /app/pyproject.toml /app/packages/wps-api/pyproject.toml /app/advisory /app/libs /app/alembic /app/alembic.ini /app/prestart.sh /app/start.sh /app/packages/wps-shared/src +RUN chmod -R a-w /app/pyproject.toml /app/packages/wps-api/pyproject.toml /app/advisory /app/libs /app/alembic /app/alembic.ini /app/prestart.sh /app/start.sh /app/packages/wps-shared/src /app/packages/wps-wf1/src # We don't know what user uv is going to run as, so we give everyone write access directories # in the app folder. We need write access for .pyc files to be created. .pyc files are good, # they speed up python. From 82fbbb45531d080af09a656f1d2217f6f2f38032 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 14:43:48 -0800 Subject: [PATCH 05/10] Add lookup --- .../wps-shared/src/wps_shared/tests/fixtures/wf1/lookup.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/packages/wps-shared/src/wps_shared/tests/fixtures/wf1/lookup.json b/backend/packages/wps-shared/src/wps_shared/tests/fixtures/wf1/lookup.json index 084f956ead..8d3c638b0e 100644 --- a/backend/packages/wps-shared/src/wps_shared/tests/fixtures/wf1/lookup.json +++ b/backend/packages/wps-shared/src/wps_shared/tests/fixtures/wf1/lookup.json @@ -102,6 +102,9 @@ "get": { "{'query': 'weatherTimestamp==1618862400000', 'page': 0, 'size': '1000'}": { "None": "wfwx/v1/dailies/rsql__query_weatherTimestamp==1618862400000_page_0_size_1000.json" + }, + "{'query': 'weatherTimestamp==1618862400000', 'page': 0, 'size': 1000}": { + "None": "wfwx/v1/dailies/rsql__query_weatherTimestamp==1618862400000_page_0_size_1000.json" } } }, From 0bcae0c79c7bf1c44f525e89affd4b20c10bb9f4 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 14:50:11 -0800 Subject: [PATCH 06/10] jobs dockerfile --- Dockerfile.jobs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/Dockerfile.jobs b/Dockerfile.jobs index 67a404052e..bd5f0e970d 100644 
--- a/Dockerfile.jobs +++ b/Dockerfile.jobs @@ -30,6 +30,8 @@ COPY ./backend/uv.lock /app/ COPY ./backend/packages/wps-jobs/pyproject.toml /app/packages/wps-jobs/ COPY ./backend/packages/wps-shared/pyproject.toml /app/packages/wps-shared/ COPY ./backend/packages/wps-shared/src /app/packages/wps-shared/src +COPY ./backend/packages/wps-wf1/pyproject.toml /app/packages/wps-wf1/ +COPY ./backend/packages/wps-wf1/src /app/packages/wps-wf1/src # Switch to root to set file permissions USER 0 @@ -37,8 +39,10 @@ USER 0 # Set configuration files to read-only for security RUN chmod 444 /app/pyproject.toml /app/uv.lock \ /app/packages/wps-jobs/pyproject.toml \ - /app/packages/wps-shared/pyproject.toml -RUN chmod -R a-w /app/packages/wps-shared/src + /app/packages/wps-shared/pyproject.toml \ + /app/packages/wps-wf1/pyproject.toml +RUN chmod -R a-w /app/packages/wps-shared/src \ + /app/packages/wps-wf1/src # Switch back to non-root user USER $USERNAME @@ -75,6 +79,7 @@ WORKDIR /app COPY --from=builder /app/pyproject.toml /app/ COPY --from=builder /app/packages/wps-jobs/pyproject.toml /app/packages/wps-jobs/ COPY --from=builder /app/packages/wps-shared/pyproject.toml /app/packages/wps-shared/ +COPY --from=builder /app/packages/wps-wf1/pyproject.toml /app/packages/wps-wf1/ # Switch back to our non-root user USER $USERNAME @@ -82,6 +87,7 @@ USER $USERNAME # Copy the jobs from src layout: COPY ./backend/packages/wps-jobs/src /app COPY ./backend/packages/wps-shared/src /app/packages/wps-shared/src +COPY ./backend/packages/wps-wf1/src /app/packages/wps-wf1/src # Copy installed Python packages COPY --from=builder /app/.venv /app/.venv @@ -96,7 +102,7 @@ USER 0 # Create writable data directory for library caches (e.g., herbie BallTree) RUN mkdir -p /data && chmod 777 /data # Remove write permissions from copied configuration and source files for security -RUN chmod -R a-w /app/pyproject.toml /app/packages/wps-jobs/pyproject.toml /app/weather_model_jobs /app/packages/wps-shared/src +RUN 
chmod -R a-w /app/pyproject.toml /app/packages/wps-jobs/pyproject.toml /app/weather_model_jobs /app/packages/wps-shared/src /app/packages/wps-wf1/src # We don't know what user uv is going to run as, so we give everyone write access directories # in the app folder. We need write access for .pyc files to be created. .pyc files are good, # they speed up python. From e73a805f43545caa9a2d442062dbf4eab856eaca Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 15:08:19 -0800 Subject: [PATCH 07/10] Remove old query_builders --- .../wps_shared/wildfire_one/query_builders.py | 128 ------------------ 1 file changed, 128 deletions(-) delete mode 100644 backend/packages/wps-shared/src/wps_shared/wildfire_one/query_builders.py diff --git a/backend/packages/wps-shared/src/wps_shared/wildfire_one/query_builders.py b/backend/packages/wps-shared/src/wps_shared/wildfire_one/query_builders.py deleted file mode 100644 index 8a0b8b3209..0000000000 --- a/backend/packages/wps-shared/src/wps_shared/wildfire_one/query_builders.py +++ /dev/null @@ -1,128 +0,0 @@ -""" Query builder classes for making requests to WFWX API """ -from typing import List, Tuple -from abc import abstractmethod, ABC - -from wps_shared import config - - -class BuildQuery(ABC): - """ Base class for building query urls and params """ - - def __init__(self): - """ Initialize object """ - self.max_page_size = config.get('WFWX_MAX_PAGE_SIZE', 1000) - self.base_url = config.get('WFWX_BASE_URL') - - @abstractmethod - def query(self, page) -> Tuple[str, dict]: - """ Return query url and params """ - - -class BuildQueryStations(BuildQuery): - """ Class for building a url and RSQL params to request all active stations. """ - - def __init__(self): - """ Prepare filtering on active, test and project stations. """ - super().__init__() - self.param_query = None - # In conversation with Dana Hicks, on Apr 20, 2021 - Dana said to show active, test and project. 
- for status in ('ACTIVE', 'TEST', 'PROJECT'): - if self.param_query: - self.param_query += f',stationStatus.id=="{status}"' - else: - self.param_query = f'stationStatus.id=="{status}"' - - def query(self, page) -> Tuple[str, dict]: - """ Return query url and params with rsql query for all weather stations marked active. """ - params = {'size': self.max_page_size, 'sort': 'displayLabel', - 'page': page, 'query': self.param_query} - url = f'{self.base_url}/v1/stations' - return url, params - - -class BuildQueryByStationCode(BuildQuery): - """ Class for building a url and params to request a list of stations by code """ - - def __init__(self, station_codes: List[int]): - """ Initialize object """ - super().__init__() - self.querystring = '' - for code in station_codes: - if len(self.querystring) > 0: - self.querystring += ' or ' - self.querystring += f'stationCode=={code}' - - def query(self, page) -> Tuple[str, dict]: - """ Return query url and params for a list of stations """ - params = {'size': self.max_page_size, - 'sort': 'displayLabel', 'page': page, 'query': self.querystring} - url = f'{self.base_url}/v1/stations/rsql' - return url, params - - -class BuildQueryAllHourliesByRange(BuildQuery): - """ Builds query for requesting all hourlies in a time range""" - - def __init__(self, start_timestamp: int, end_timestamp: int): - """ Initialize object """ - super().__init__() - self.querystring: str = "weatherTimestamp >=" + \ - str(start_timestamp) + ";" + "weatherTimestamp <" + str(end_timestamp) - - def query(self, page) -> Tuple[str, dict]: - """ Return query url for hourlies between start_timestamp, end_timestamp""" - params = {'size': self.max_page_size, 'page': page, 'query': self.querystring} - url = f'{self.base_url}/v1/hourlies/rsql' - return url, params - - -class BuildQueryAllForecastsByAfterStart(BuildQuery): - """ Builds query for requesting all dailies in a time range""" - - def __init__(self, start_timestamp: int): - """ Initialize object """ - 
super().__init__() - self.querystring = f"weatherTimestamp >={start_timestamp};recordType.id == 'FORECAST'" - - def query(self, page) -> Tuple[str, dict]: - """ Return query url for dailies between start_timestamp, end_timestamp""" - params = {'size': self.max_page_size, 'page': page, 'query': self.querystring} - url = f'{self.base_url}/v1/dailies/rsql' - return url, params - - -class BuildQueryDailiesByStationCode(BuildQuery): - """ Builds query for requesting dailies in a time range for the station codes""" - - def __init__(self, start_timestamp: int, end_timestamp: int, station_ids: List[str]): - """ Initialize object """ - super().__init__() - self.start_timestamp = start_timestamp - self.end_timestamp = end_timestamp - self.station_ids = station_ids - - def query(self, page) -> Tuple[str, dict]: - """ Return query url for dailies between start_timestamp, end_timestamp""" - params = {'size': self.max_page_size, - 'page': page, - 'startingTimestamp': self.start_timestamp, - 'endingTimestamp': self.end_timestamp, - 'stationIds': self.station_ids} - url = (f'{self.base_url}/v1/dailies/search/findDailiesByStationIdIsInAndWeather' + - 'TimestampBetweenOrderByStationIdAscWeatherTimestampAsc') - return url, params - - -class BuildQueryStationGroups(BuildQuery): - """ Builds a query for requesting all station groups """ - - def __init__(self): - """ Initilize object. """ - super().__init__() - self.param_query = None - - def query(self, page) -> Tuple[str, dict]: - """ Return query url and params with query for all weather stations groups. 
""" - params = {'size': self.max_page_size, 'page': page, 'sort': 'groupOwnerUserId,asc'} - url = f'{self.base_url}/v1/stationGroups' - return url, params From d37969c73e2b80b12baef8dde047e5b6562dfbf6 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Tue, 30 Dec 2025 15:16:12 -0800 Subject: [PATCH 08/10] Remove tests for non-existent code --- .../tests/wildfire_one/test_wildfire_one.py | 46 +++---------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py index bece959c26..34eb202005 100644 --- a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py +++ b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wildfire_one.py @@ -7,55 +7,21 @@ import wps_shared.wildfire_one.wfwx_post_api from fastapi import HTTPException from pytest_mock import MockFixture -from wps_shared.wildfire_one.query_builders import ( - BuildQueryAllForecastsByAfterStart, - BuildQueryAllHourliesByRange, - BuildQueryDailiesByStationCode, - BuildQueryStationGroups, -) from wps_shared.wildfire_one.wfwx_api import ( WFWXWeatherStation, get_wfwx_stations_from_station_codes, ) from wps_shared.wildfire_one.wfwx_post_api import post_forecasts - -def test_build_all_hourlies_query(): - """Verifies the query builder returns the correct url and parameters""" - query_builder = BuildQueryAllHourliesByRange(0, 1) - result = query_builder.query(0) - assert result == ("https://wf1/wfwx/v1/hourlies/rsql", {"size": "1000", "page": 0, "query": "weatherTimestamp >=0;weatherTimestamp <1"}) - - -def test_build_forecasts_query(): - """Verifies the query builder returns the correct url and parameters""" - query_builder = BuildQueryAllForecastsByAfterStart(0) - result = query_builder.query(0) - assert result == ("https://wf1/wfwx/v1/dailies/rsql", {"size": "1000", "page": 0, "query": 
"weatherTimestamp >=0;recordType.id == 'FORECAST'"}) - - -def test_build_dailies_by_station_code(): - """Verifies the query builder returns the correct url and parameters for dailies by station code""" - query_builder = BuildQueryDailiesByStationCode(0, 1, ["1", "2"]) - result = query_builder.query(0) - assert result == ( - "https://wf1/wfwx/v1/dailies/search/" + "findDailiesByStationIdIsInAndWeather" + "TimestampBetweenOrderByStationIdAscWeatherTimestampAsc", - {"size": "1000", "page": 0, "startingTimestamp": 0, "endingTimestamp": 1, "stationIds": ["1", "2"]}, - ) - - -def test_build_station_groups_query(): - """Verifies the query builder returns the correct url and parameters for a station groups query""" - query_builder = BuildQueryStationGroups() - result = query_builder.query(0) - assert result == ("https://wf1/wfwx/v1/stationGroups", {"size": "1000", "page": 0, "sort": "groupOwnerUserId,asc"}) - - code1 = 322 code2 = 239 all_station_codes = [{"station_code": code1}, {"station_code": code2}] -station_1 = WFWXWeatherStation(code=code1, name="name", wfwx_id="one", latitude=0, longitude=0, elevation=0, zone_code="T1") -station_2 = WFWXWeatherStation(code=code2, name="name", wfwx_id="two", latitude=0, longitude=0, elevation=0, zone_code="T1") +station_1 = WFWXWeatherStation( + code=code1, name="name", wfwx_id="one", latitude=0, longitude=0, elevation=0, zone_code="T1" +) +station_2 = WFWXWeatherStation( + code=code2, name="name", wfwx_id="two", latitude=0, longitude=0, elevation=0, zone_code="T1" +) all_stations = [station_1, station_2] From 3b0d80b0844e5eea4462917d526ad42df5cbda86 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Wed, 31 Dec 2025 11:16:01 -0800 Subject: [PATCH 09/10] wfwx_api tests --- .../tests/wildfire_one/test_wfwx_api.py | 908 ++++++++++++++++++ 1 file changed, 908 insertions(+) create mode 100644 backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wfwx_api.py diff --git 
class MockAsyncGenerator:
    """Minimal async iterator that yields the supplied items one at a time.

    Note: the given list is consumed (popped from the front) as iteration
    proceeds, mirroring single-use async-generator semantics.
    """

    def __init__(self, items):
        # Remaining items; drained from the head on each __anext__ call.
        self.items = items

    def __aiter__(self):
        # The object serves as its own iterator.
        return self

    async def __anext__(self):
        if not self.items:
            raise StopAsyncIteration
        return self.items.pop(0)
@patch("wps_shared.wildfire_one.wfwx_api.config.get") + async def test_get_auth_header(self, mock_config, mock_create_client): + """Test get_auth_header returns correct auth header""" + # Setup mocks + mock_config.return_value = "600" + mock_client = AsyncMock() + mock_client.fetch_access_token.return_value = {"access_token": "test_token"} + mock_create_client.return_value = mock_client + + session = MagicMock(spec=ClientSession) + + # Call function + result = await get_auth_header(session) + + # Verify result + assert result == {"Authorization": "Bearer test_token"} + mock_client.fetch_access_token.assert_called_once_with(600) + + +class TestGetNoCacheAuthHeader: + """Test cases for the get_no_cache_auth_header function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header") + async def test_get_no_cache_auth_header(self, mock_get_auth_header): + """Test get_no_cache_auth_header adds no-cache header""" + # Setup mocks + mock_get_auth_header.return_value = {"Authorization": "Bearer test_token"} + + session = MagicMock(spec=ClientSession) + + # Call function + result = await get_no_cache_auth_header(session) + + # Verify result + assert result == {"Authorization": "Bearer test_token", "Cache-Control": "no-cache"} + mock_get_auth_header.assert_called_once_with(session) + + +class TestGetStationsByCodes: + """Test cases for the get_stations_by_codes function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.EcodivisionSeasons") + @patch("wps_shared.wildfire_one.wfwx_api.ClientSession") + @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header") + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + @patch("wps_shared.wildfire_one.wfwx_api.config.get") + @patch("wps_shared.wildfire_one.wfwx_api.is_station_valid") + @patch("wps_shared.wildfire_one.wfwx_api.parse_station") + async def test_get_stations_by_codes_with_valid_stations( + self, + mock_parse_station, + mock_is_station_valid, + mock_config, + 
class TestGetStationData:
    """Tests for the get_station_data function."""

    @pytest.mark.anyio
    @patch("wps_shared.wildfire_one.wfwx_api.ClientSession")
    @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header")
    @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client")
    @patch("wps_shared.wildfire_one.wfwx_api.config.get")
    async def test_get_station_data(
        self, mock_config, mock_create_client, mock_get_auth_header, mock_client_session
    ):
        """get_station_data returns whatever the supplied mapper produces."""
        mock_config.return_value = "604800"
        mock_get_auth_header.return_value = {"Authorization": "Bearer token"}

        raw_stations = [
            {"id": 1, "stationCode": 101, "displayLabel": "Station 1"},
            {"id": 2, "stationCode": 102, "displayLabel": "Station 2"},
            {"id": 3, "stationCode": 103, "displayLabel": "Station 3"},
        ]
        wf1_client = MagicMock()
        wf1_client.fetch_paged_response_generator = MagicMock(return_value=agen(raw_stations))
        mock_create_client.return_value = wf1_client

        mapper = AsyncMock(return_value=["mapped_station1", "mapped_station2"])

        # get_station_data opens its own ClientSession context manager internally.
        session = MagicMock()
        mock_client_session.return_value.__aenter__.return_value = session

        result = await get_station_data(session, {"Authorization": "Bearer token"}, mapper=mapper)

        assert result == ["mapped_station1", "mapped_station2"]
"id": "station1", + "stationCode": 101, + "displayLabel": "Station 1", + "longitude": -123.45, + "latitude": 49.28, + "stationStatus": {"id": "ACTIVE"}, + }, + { + "id": "station2", + "stationCode": 102, + "displayLabel": "Station 2", + "longitude": -123.46, + "latitude": 49.29, + "stationStatus": {"id": "ACTIVE"}, + }, + ] + ) + ) + mock_create_client.return_value = mock_client + + session = MagicMock(spec=ClientSession) + headers = {"Authorization": "Bearer token"} + query_builder = MagicMock() + + # Call function + stations, id_to_code_map = await get_detailed_geojson_stations( + session, headers, query_builder + ) + + # Verify result + assert len(stations) == 2 + assert 101 in stations + assert 102 in stations + assert stations[101].properties.code == 101 + assert stations[101].properties.name == "Station 1" + assert stations[101].geometry.coordinates == [-123.45, 49.28] + + assert id_to_code_map == {"station1": 101, "station2": 102} + + +class TestGetDetailedStations: + """Test cases for the get_detailed_stations function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header") + @patch("wps_shared.wildfire_one.wfwx_api._get_noon_date") + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + @patch("wps_shared.wildfire_one.wfwx_api.get_detailed_geojson_stations") + @patch("wps_shared.wildfire_one.wfwx_api.ClientSession") + @patch("wps_shared.wildfire_one.wfwx_api.TCPConnector") + @patch("wps_shared.wildfire_one.wfwx_api.WeatherVariables") + async def test_get_detailed_stations( + self, + mock_weather_variables, + mock_tcp_connector, + mock_client_session, + mock_get_detailed_geojson_stations, + mock_create_client, + mock_get_noon_date, + mock_get_auth_header, + ): + """Test get_detailed_stations returns stations with weather data""" + # Setup mocks + time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + noon_time = datetime(2023, 1, 1, 12, 0, 0) + mock_get_noon_date.return_value = noon_time + 
mock_get_auth_header.return_value = {"Authorization": "Bearer token"} + + mock_client = AsyncMock() + mock_client.fetch_raw_dailies_for_all_stations.return_value = [ + { + "stationId": "station1", + "temperature": 20.5, + "relativeHumidity": 65.0, + "recordType": {"id": "ACTUAL"}, + } + ] + mock_create_client.return_value = mock_client + + # Mock geojson stations + station1 = MagicMock(spec=GeoJsonDetailedWeatherStation) + station1.properties = MagicMock() + station1.properties.observations = None + station1.properties.forecasts = None + + stations_dict = {101: station1} + id_to_code_map = {"station1": 101} + mock_get_detailed_geojson_stations.return_value = (stations_dict, id_to_code_map) + + mock_weather_variables.return_value = MagicMock(spec=WeatherVariables) + + # Mock TCPConnector + mock_connector = MagicMock() + mock_tcp_connector.return_value = mock_connector + + # Mock ClientSession context manager + mock_session_instance = MagicMock() + mock_client_session.return_value.__aenter__.return_value = mock_session_instance + + # Call function + result = await get_detailed_stations(time_of_interest) + + # Verify result + assert len(result) == 1 + assert result[0] == station1 + assert station1.properties.observations is not None + + +class TestGetNoonForecastsAllStations: + """Test cases for the get_noon_forecasts_all_stations function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + @patch("wps_shared.wildfire_one.wfwx_api.get_station_data") + @patch("wps_shared.wildfire_one.wfwx_api.parse_noon_forecast") + async def test_get_noon_forecasts_all_stations( + self, mock_parse_noon_forecast, mock_get_station_data, mock_create_client + ): + """Test get_noon_forecasts_all_stations returns forecasts for all stations""" + # Setup mocks + mock_client = MagicMock() + mock_client.fetch_paged_response_generator = MagicMock( + return_value=agen( + [{"stationId": "wfwx1", "temp": 20.0}, {"stationId": "wfwx2", "temp": 22.0}] + ) + ) 
+ mock_create_client.return_value = mock_client + + mock_station1 = MagicMock(spec=WFWXWeatherStation) + mock_station1.wfwx_id = "wfwx1" + mock_station1.code = 101 + mock_station2 = MagicMock(spec=WFWXWeatherStation) + mock_station2.wfwx_id = "wfwx2" + mock_station2.code = 102 + mock_get_station_data.return_value = [mock_station1, mock_station2] + + mock_forecast1 = MagicMock(spec=NoonForecast) + mock_forecast2 = MagicMock(spec=NoonForecast) + mock_parse_noon_forecast.side_effect = [mock_forecast1, mock_forecast2] + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + start_timestamp = datetime(2023, 1, 1, 0, 0, 0) + + # Call function + result = await get_noon_forecasts_all_stations(session, header, start_timestamp) + + # Verify result + assert len(result) == 2 + assert result[0] == mock_forecast1 + assert result[1] == mock_forecast2 + + +class TestGetHourlyActualsAllStations: + """Test cases for the get_hourly_actuals_all_stations function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + @patch("wps_shared.wildfire_one.wfwx_api.get_station_data") + @patch("wps_shared.wildfire_one.wfwx_api.parse_hourly_actual") + async def test_get_hourly_actuals_all_stations( + self, mock_parse_hourly_actual, mock_get_station_data, mock_create_client + ): + """Test get_hourly_actuals_all_stations returns actuals for all stations""" + # Setup mocks + mock_client = MagicMock() + mock_client.fetch_paged_response_generator = MagicMock( + return_value=agen( + [ + { + "stationId": "wfwx1", + "hourlyMeasurementTypeCode": {"id": "ACTUAL"}, + "temp": 20.0, + }, + { + "stationId": "wfwx2", + "hourlyMeasurementTypeCode": {"id": "FORECAST"}, # Should be filtered out + "temp": 22.0, + }, + { + "stationId": "wfwx1", + "hourlyMeasurementTypeCode": {"id": "ACTUAL"}, + "temp": 21.0, + }, + ] + ) + ) + mock_create_client.return_value = mock_client + + mock_station1 = MagicMock(spec=WFWXWeatherStation) + 
mock_station1.wfwx_id = "wfwx1" + mock_station1.code = 101 + mock_station2 = MagicMock(spec=WFWXWeatherStation) + mock_station2.wfwx_id = "wfwx2" + mock_station2.code = 102 + mock_get_station_data.return_value = [mock_station1, mock_station2] + + mock_actual1 = MagicMock(spec=HourlyActual) + mock_actual2 = MagicMock(spec=HourlyActual) + mock_parse_hourly_actual.side_effect = [mock_actual1, mock_actual2] + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + start_timestamp = datetime(2023, 1, 1, 0, 0, 0) + end_timestamp = datetime(2023, 1, 1, 23, 59, 59) + + # Call function + result = await get_hourly_actuals_all_stations( + session, header, start_timestamp, end_timestamp + ) + + # Verify result - should only include ACTUAL records (2 out of 3) + assert len(result) == 2 + assert result[0] == mock_actual1 + assert result[1] == mock_actual2 + + +class TestGetWfwxStationsFromStationCodes: + """Test cases for the get_wfwx_stations_from_station_codes function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_station_data") + @patch("wps_shared.wildfire_one.wfwx_api.get_fire_centre_station_codes") + async def test_get_wfwx_stations_from_station_codes_with_specific_codes( + self, mock_get_fire_centre_station_codes, mock_get_station_data + ): + """Test get_wfwx_stations_from_station_codes returns specific stations when codes provided""" + # Setup mocks + mock_get_fire_centre_station_codes.return_value = [101, 102, 103, 104] + + mock_station1 = MagicMock(spec=WFWXWeatherStation) + mock_station1.code = 101 + mock_station1.wfwx_id = "wfwx1" + mock_station2 = MagicMock(spec=WFWXWeatherStation) + mock_station2.code = 102 + mock_station2.wfwx_id = "wfwx2" + mock_station3 = MagicMock(spec=WFWXWeatherStation) + mock_station3.code = 103 + mock_station3.wfwx_id = "wfwx3" + mock_get_station_data.return_value = [mock_station1, mock_station2, mock_station3] + + session = MagicMock(spec=ClientSession) + header = 
{"Authorization": "Bearer token"} + station_codes = [101, 102] + + # Call function + result = await get_wfwx_stations_from_station_codes(session, header, station_codes) + + # Verify result + assert len(result) == 2 + assert result[0] == mock_station1 + assert result[1] == mock_station2 + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_station_data") + @patch("wps_shared.wildfire_one.wfwx_api.get_fire_centre_station_codes") + async def test_get_wfwx_stations_from_station_codes_none_returns_all( + self, mock_get_fire_centre_station_codes, mock_get_station_data + ): + """Test get_wfwx_stations_from_station_codes returns all fire centre stations when None provided""" + # Setup mocks + mock_get_fire_centre_station_codes.return_value = [ + 101, + 102, + ] # Only station 101, 102 are in fire centre + + mock_station1 = MagicMock(spec=WFWXWeatherStation) + mock_station1.code = 101 + mock_station1.wfwx_id = "wfwx1" + mock_station2 = MagicMock(spec=WFWXWeatherStation) + mock_station2.code = 102 + mock_station2.wfwx_id = "wfwx2" + mock_station3 = MagicMock(spec=WFWXWeatherStation) + mock_station3.code = 103 # Not in fire centre + mock_station3.wfwx_id = "wfwx3" + mock_get_station_data.return_value = [mock_station1, mock_station2, mock_station3] + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + station_codes = None + + # Call function + result = await get_wfwx_stations_from_station_codes(session, header, station_codes) + + # Verify result - should only include fire centre stations + assert len(result) == 2 + assert result[0] == mock_station1 + assert result[1] == mock_station2 + + +class TestGetRawDailiesInRangeGenerator: + """Test cases for the get_raw_dailies_in_range_generator function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + async def test_get_raw_dailies_in_range_generator(self, mock_create_client): + """Test get_raw_dailies_in_range_generator returns correct 
generator""" + # Setup mocks + mock_client = MagicMock() + mock_client.fetch_paged_response_generator = MagicMock( + return_value=MockAsyncGenerator( + [{"id": "daily1", "temp": 20.0}, {"id": "daily2", "temp": 22.0}] + ) + ) + mock_create_client.return_value = mock_client + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + wfwx_station_ids = ["wfwx1", "wfwx2"] + start_timestamp = 1672531200000 # 2023-01-01 00:00:00 in milliseconds + end_timestamp = 1672617600000 # 2023-01-02 00:00:00 in milliseconds + + # Call function + result = await get_raw_dailies_in_range_generator( + session, header, wfwx_station_ids, start_timestamp, end_timestamp + ) + + # Verify result is the generator + assert isinstance(result, MockAsyncGenerator) + + # Verify client was called correctly + mock_create_client.assert_called_once_with(session) + + +class TestGetDailiesGenerator: + """Test cases for the get_dailies_generator function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client") + @patch("wps_shared.wildfire_one.wfwx_api.config.get") + async def test_get_dailies_generator_with_cache(self, mock_config, mock_create_client): + """Test get_dailies_generator uses cache when enabled""" + mock_mapping = {"REDIS_DAILIES_BY_STATION_CODE_CACHE_EXPIRY": "300", "REDIS_USE": "True"} + + # Setup mocks + def mock_config_get(key, default=None): + return mock_mapping.get(key, default) + + mock_config.side_effect = mock_config_get + + mock_client = MagicMock() + mock_client.fetch_paged_response_generator = MagicMock( + return_value=MockAsyncGenerator([{"id": "daily1", "temp": 20.0}]) + ) + mock_create_client.return_value = mock_client + + mock_station = MagicMock(spec=WFWXWeatherStation) + mock_station.wfwx_id = "wfwx1" + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + wfwx_stations = [mock_station] + time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + end_time_of_interest = 
datetime(2023, 1, 2, 12, 0, 0) + check_cache = True + + # Call function + result = await get_dailies_generator( + session, header, wfwx_stations, time_of_interest, end_time_of_interest, check_cache + ) + + # Verify result + assert isinstance(result, MockAsyncGenerator) + + # Verify cache was used + mock_client.fetch_paged_response_generator.assert_called_once() + + +class TestGetFireCenters: + """Test cases for the get_fire_centers function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_station_data") + async def test_get_fire_centers(self, mock_get_station_data): + """Test get_fire_centers returns list of fire centres""" + # Setup mocks + mock_fire_center1 = MagicMock(spec=FireCentre) + mock_fire_center2 = MagicMock(spec=FireCentre) + fire_centers_dict = {"center1": mock_fire_center1, "center2": mock_fire_center2} + mock_get_station_data.return_value = fire_centers_dict + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + + # Call function + result = await get_fire_centers(session, header) + + # Verify result + assert len(result) == 2 + assert mock_fire_center1 in result + assert mock_fire_center2 in result + + +class TestGetDailiesForStationsAndDate: + """Test cases for the get_dailies_for_stations_and_date function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.is_station_valid") + @patch("wps_shared.wildfire_one.wfwx_api.get_wfwx_stations_from_station_codes") + @patch("wps_shared.wildfire_one.wfwx_api.get_dailies_generator") + async def test_get_dailies_for_stations_and_date( + self, mock_get_dailies_generator, mock_get_wfwx_stations, mock_is_station_valid + ): + """Test get_dailies_for_stations_and_date returns mapped dailies""" + # Setup mocks + mock_wfwx_station = MagicMock(spec=WFWXWeatherStation) + mock_wfwx_station.wfwx_id = "wfwx1" + mock_get_wfwx_stations.return_value = [mock_wfwx_station] + mock_is_station_valid.return_value = True + mock_dailies_list_mapper_result = 
["mapped_daily1", "mapped_daily2"] + mock_dailies_list_mapper = AsyncMock(return_value=mock_dailies_list_mapper_result) + + # Need to access and modify the __defaults__ of the function being called (get_dailies_for_stations_and_date) + original_defaults = get_dailies_for_stations_and_date.__defaults__ + try: + new_defaults = list(original_defaults or []) + if len(new_defaults) == 0: + # If there were no defaults, we must build a defaults tuple that + # matches number of rightmost defaulted params. + new_defaults = [mock_dailies_list_mapper] + else: + # Replace the last default (mapper) + new_defaults[-1] = mock_dailies_list_mapper + get_dailies_for_stations_and_date.__defaults__ = tuple(new_defaults) + + mock_dailies_generator = MockAsyncGenerator( + [{"id": "daily1", "temp": 20.0}, {"id": "daily2", "temp": 22.0}] + ) + mock_get_dailies_generator.return_value = mock_dailies_generator + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + start_time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + end_time_of_interest = datetime(2023, 1, 2, 12, 0, 0) + unique_station_codes = [101, 102] + + # Call function + result = await get_dailies_for_stations_and_date( + session, header, start_time_of_interest, end_time_of_interest, unique_station_codes + ) + + # Verify result + assert result == mock_dailies_list_mapper_result + assert mock_dailies_list_mapper.await_count == 1 + finally: + # Always restore defaults to avoid test bleed-over + get_dailies_for_stations_and_date.__defaults__ = original_defaults + + +class TestGetForecastsForStationsByDateRange: + """Test cases for the get_forecasts_for_stations_by_date_range function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.is_station_valid") + @patch("wps_shared.wildfire_one.wfwx_api.get_wfwx_stations_from_station_codes") + @patch("wps_shared.wildfire_one.wfwx_api.get_dailies_generator") + async def test_get_forecasts_for_stations_by_date_range( + self, 
mock_get_dailies_generator, mock_get_wfwx_stations, mock_is_station_valid
+    ):
+        """Test get_forecasts_for_stations_by_date_range returns forecast dailies"""
+        # Setup mocks
+        mock_wfwx_station = MagicMock(spec=WFWXWeatherStation)
+        mock_wfwx_station.wfwx_id = "wfwx1"
+        mock_get_wfwx_stations.return_value = [mock_wfwx_station]
+        mock_is_station_valid.return_value = True
+
+        mock_dailies_generator = MockAsyncGenerator(
+            [{"id": "forecast1", "temp": 20.0}, {"id": "forecast2", "temp": 22.0}]
+        )
+        mock_get_dailies_generator.return_value = mock_dailies_generator
+        mock_dailies_list_mapper_result = [
+            MagicMock(spec=StationDailyFromWF1),
+            MagicMock(spec=StationDailyFromWF1),
+        ]
+        mock_dailies_list_mapper = AsyncMock(return_value=mock_dailies_list_mapper_result)
+
+        # Need to access and modify the __defaults__ of the function being called (get_forecasts_for_stations_by_date_range)
+        original_defaults = get_forecasts_for_stations_by_date_range.__defaults__
+        try:
+            new_defaults = list(original_defaults or [])
+            if len(new_defaults) == 0:
+                # If there were no defaults, we must build a defaults tuple that
+                # matches number of rightmost defaulted params.
+ new_defaults = [mock_dailies_list_mapper] + else: + # Replace the last default (mapper) + new_defaults[-1] = mock_dailies_list_mapper + get_forecasts_for_stations_by_date_range.__defaults__ = tuple(new_defaults) + + session = MagicMock(spec=ClientSession) + header = {"Authorization": "Bearer token"} + start_time_of_interest = datetime(2023, 1, 1, 12, 0, 0) + end_time_of_interest = datetime(2023, 1, 2, 12, 0, 0) + unique_station_codes = [101, 102] + + # Call function + result = await get_forecasts_for_stations_by_date_range( + session, header, start_time_of_interest, end_time_of_interest, unique_station_codes + ) + + # Verify result + assert result == mock_dailies_list_mapper_result + assert mock_dailies_list_mapper.await_count == 1 + finally: + # Always restore defaults to avoid test bleed-over + get_forecasts_for_stations_by_date_range.__defaults__ = original_defaults + + +class TestGetDailyDeterminatesForStationsAndDate: + """Test cases for the get_daily_determinates_for_stations_and_date function""" + + @pytest.mark.anyio + @patch("wps_shared.wildfire_one.wfwx_api.get_wfwx_stations_from_station_codes") + @patch("wps_shared.wildfire_one.wfwx_api.get_dailies_generator") + async def test_get_daily_determinates_for_stations_and_date( + self, mock_get_dailies_generator, mock_get_wfwx_stations + ): + """Test get_daily_determinates_for_stations_and_date returns actuals and forecasts""" + # Setup mocks + mock_wfwx_station = MagicMock(spec=WFWXWeatherStation) + mock_wfwx_station.wfwx_id = "wfwx1" + mock_get_wfwx_stations.return_value = [mock_wfwx_station] + + mock_dailies_generator = MockAsyncGenerator( + [{"id": "daily1", "temp": 20.0}, {"id": "daily2", "temp": 22.0}] + ) + mock_get_dailies_generator.return_value = mock_dailies_generator + + mock_actuals = ["actual1", "actual2"] + mock_forecasts = ["forecast1", "forecast2"] + mock_weather_indeterminate_list_mapper = AsyncMock( + return_value=(mock_actuals, mock_forecasts) + ) + + # Need to access and modify the 
__defaults__ of the function being called (get_daily_determinates_for_stations_and_date)
+        original_defaults = get_daily_determinates_for_stations_and_date.__defaults__
+        try:
+            new_defaults = list(original_defaults or [])
+            if len(new_defaults) == 0:
+                # If there were no defaults, we must build a defaults tuple that
+                # matches number of rightmost defaulted params.
+                new_defaults = [mock_weather_indeterminate_list_mapper]
+            else:
+                # Replace the first default (mapper)
+                new_defaults[0] = mock_weather_indeterminate_list_mapper
+            get_daily_determinates_for_stations_and_date.__defaults__ = tuple(new_defaults)
+
+            session = MagicMock(spec=ClientSession)
+            header = {"Authorization": "Bearer token"}
+            start_time_of_interest = datetime(2023, 1, 1, 12, 0, 0)
+            end_time_of_interest = datetime(2023, 1, 2, 12, 0, 0)
+            unique_station_codes = [101, 102]
+
+            # Call function
+            result = await get_daily_determinates_for_stations_and_date(
+                session, header, start_time_of_interest, end_time_of_interest, unique_station_codes
+            )
+
+            # Verify result
+            assert result == (mock_actuals, mock_forecasts)
+            assert mock_weather_indeterminate_list_mapper.await_count == 1
+        finally:
+            # Always restore defaults to avoid test bleed-over
+            get_daily_determinates_for_stations_and_date.__defaults__ = original_defaults
+
+
+class TestGetStationGroups:
+    """Test cases for the get_station_groups function"""
+
+    @pytest.mark.anyio
+    @patch("wps_shared.wildfire_one.wfwx_api.ClientSession")
+    @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header")
+    @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client")
+    async def test_get_station_groups(
+        self, mock_create_client, mock_get_auth_header, mock_client_session
+    ):
+        """Test get_station_groups returns mapped station groups"""
+        # Setup mocks
+        mock_get_auth_header.return_value = {"Authorization": "Bearer token"}
+
+        mock_client = AsyncMock()
+        mock_client.fetch_paged_response_generator.return_value = MockAsyncGenerator(
+            [{"id": "group1", "name": "Group 1"}, {"id": "group2", "name": "Group 2"}]
+        )
+        mock_create_client.return_value = mock_client
+
+        mock_weather_station_group_mapper_value = ["mapped_group1", "mapped_group2"]
+        mock_weather_station_group_mapper = AsyncMock(
+            return_value=mock_weather_station_group_mapper_value
+        )
+
+        # Need to access and modify the __defaults__ of the function being called (get_station_groups)
+        original_defaults = get_station_groups.__defaults__
+        try:
+            new_defaults = list(original_defaults or [])
+            if len(new_defaults) == 0:
+                # If there were no defaults, we must build a defaults tuple that
+                # matches number of rightmost defaulted params.
+                new_defaults = [mock_weather_station_group_mapper]
+            else:
+                # Replace the last (only) default (mapper)
+                new_defaults[-1] = mock_weather_station_group_mapper
+            get_station_groups.__defaults__ = tuple(new_defaults)
+
+            # Mock ClientSession context manager
+            mock_session_instance = MagicMock()
+            mock_client_session.return_value.__aenter__.return_value = mock_session_instance
+
+            # Call function
+            result = await get_station_groups()
+
+            # Verify result
+            assert result == mock_weather_station_group_mapper_value
+        finally:
+            # Always restore defaults to avoid test bleed-over
+            get_station_groups.__defaults__ = original_defaults
+
+
+class TestGetStationsByGroupIds:
+    """Test cases for the get_stations_by_group_ids function"""
+
+    @pytest.mark.anyio
+    @patch("wps_shared.wildfire_one.wfwx_api.ClientSession")
+    @patch("wps_shared.wildfire_one.wfwx_api.get_auth_header")
+    @patch("wps_shared.wildfire_one.wfwx_api.create_wps_wf1_client")
+    async def test_get_stations_by_group_ids(
+        self, mock_create_client, mock_get_auth_header, mock_client_session
+    ):
+        """Test get_stations_by_group_ids returns stations from all groups"""
+        # Setup mocks
+        mock_get_auth_header.return_value = {"Authorization": "Bearer token"}
+
+        mock_client = AsyncMock()
+        mock_client.fetch_stations_by_group_id.side_effect = [
+            {"_embedded": {"stations": [{"id": "station1"}]}},
+            {"_embedded": {"stations": [{"id": "station2"}]}},
+        ]
+        mock_create_client.return_value = mock_client
+
+        # Mock ClientSession context manager
+        mock_session_instance = MagicMock()
+        mock_client_session.return_value.__aenter__.return_value = mock_session_instance
+
+        # Call function
+        group_ids = ["group1"]
+
+        mock_mapped_stations_group1 = "mapped_station1"
+        mock_mapped_stations_group2 = "mapped_station2"
+        mock_unique_weather_stations_mapper_value = [
+            mock_mapped_stations_group1,
+            mock_mapped_stations_group2,
+        ]
+        mock_unique_weather_stations_mapper = MagicMock(
+            return_value=mock_unique_weather_stations_mapper_value
+        )
+
+        # Need to access and modify the __defaults__ of the function being called (get_stations_by_group_ids)
+        original_defaults = get_stations_by_group_ids.__defaults__
+        try:
+            new_defaults = list(original_defaults or [])
+            if len(new_defaults) == 0:
+                # If there were no defaults, we must build a defaults tuple that
+                # matches number of rightmost defaulted params.
+ new_defaults = [mock_unique_weather_stations_mapper] + else: + # Replace the last (only) default (mapper) + new_defaults[-1] = mock_unique_weather_stations_mapper + get_stations_by_group_ids.__defaults__ = tuple(new_defaults) + + result = await get_stations_by_group_ids(group_ids) + + # Verify result + assert result == ["mapped_station1", "mapped_station2"] + assert mock_client.fetch_stations_by_group_id.call_count == 1 + assert mock_unique_weather_stations_mapper.call_count == 1 + finally: + # Always restore defaults to avoid test bleed-over + get_stations_by_group_ids.__defaults__ = original_defaults From 722e646ce1071f4c6f400a88a93cd6bf5f0df955 Mon Sep 17 00:00:00 2001 From: Darren Boss Date: Wed, 31 Dec 2025 12:38:20 -0800 Subject: [PATCH 10/10] Use MockAsyncGenerator --- .../wps_shared/tests/wildfire_one/test_wfwx_api.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wfwx_api.py b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wfwx_api.py index 2e04175e97..4791f2bf75 100644 --- a/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wfwx_api.py +++ b/backend/packages/wps-shared/src/wps_shared/tests/wildfire_one/test_wfwx_api.py @@ -14,7 +14,6 @@ WeatherStation, WeatherVariables, ) -from wps_shared.tests.conftest import agen from wps_shared.wildfire_one.schema_parsers import ( WFWXWeatherStation, ) @@ -127,7 +126,7 @@ async def test_get_stations_by_codes_with_valid_stations( mock_client = MagicMock() mock_client.fetch_paged_response_generator = MagicMock( - return_value=agen( + return_value=MockAsyncGenerator( [ {"id": 1, "stationCode": 101, "displayLabel": "Station 1"}, {"id": 2, "stationCode": 102, "displayLabel": "Station 2"}, @@ -183,7 +182,7 @@ async def test_get_station_data( mock_client = MagicMock() mock_client.fetch_paged_response_generator = MagicMock( - return_value=agen( + return_value=MockAsyncGenerator( [ {"id": 1, 
"stationCode": 101, "displayLabel": "Station 1"}, {"id": 2, "stationCode": 102, "displayLabel": "Station 2"}, @@ -225,7 +224,7 @@ async def test_get_detailed_geojson_stations_with_valid_stations( mock_client = MagicMock() mock_client.fetch_paged_response_generator = MagicMock( - return_value=agen( + return_value=MockAsyncGenerator( [ { "id": "station1", @@ -350,7 +349,7 @@ async def test_get_noon_forecasts_all_stations( # Setup mocks mock_client = MagicMock() mock_client.fetch_paged_response_generator = MagicMock( - return_value=agen( + return_value=MockAsyncGenerator( [{"stationId": "wfwx1", "temp": 20.0}, {"stationId": "wfwx2", "temp": 22.0}] ) ) @@ -395,7 +394,7 @@ async def test_get_hourly_actuals_all_stations( # Setup mocks mock_client = MagicMock() mock_client.fetch_paged_response_generator = MagicMock( - return_value=agen( + return_value=MockAsyncGenerator( [ { "stationId": "wfwx1",