diff --git a/tests/test_sdk.py b/tests/test_sdk.py index 4a7f952b..10cddb0a 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -143,6 +143,18 @@ def test_mock_register(self, mock_post): resp = self.base.register(email=os.getenv("WATTTIME_EMAIL")) self.assertEqual(len(mock_post.call_args_list), 1) + def test_get_password(self): + + with mock.patch.dict(os.environ, {}, clear=True), self.assertRaises(ValueError): + wt_base = WattTimeBase() + + with mock.patch.dict(os.environ, {}, clear=True): + wt_base = WattTimeBase( + username="WATTTIME_USERNAME", password="WATTTIME_PASSWORD" + ) + self.assertEqual(wt_base.username, "WATTTIME_USERNAME") + self.assertEqual(wt_base.password, "WATTTIME_PASSWORD") + class TestWattTimeHistorical(unittest.TestCase): def setUp(self): @@ -464,6 +476,11 @@ def test_historical_forecast_jsons_multithreaded(self): class TestWattTimeMaps(unittest.TestCase): def setUp(self): self.maps = WattTimeMaps() + self.myaccess = WattTimeMyAccess() + + def tearDown(self): + self.maps.session.close() + self.myaccess.session.close() def tearDown(self): self.maps.session.close() @@ -503,6 +520,22 @@ def test_region_from_loc(self): self.assertEqual(region["region_full_name"], "Public Service Co of Colorado") self.assertEqual(region["signal_type"], "co2_moer") + def test_my_access_in_geojson(self): + access = self.myaccess.get_access_pandas() + for signal_type in ["co2_moer", "co2_aoer", "health_damage"]: + access_regions = access.loc[ + access["signal_type"] == signal_type, "region" + ].unique() + maps = self.maps.get_maps_json(signal_type=signal_type) + maps_regions = [i["properties"]["region"] for i in maps["features"]] + + assert ( + set(access_regions) - set(maps_regions) == set() + ), f"Missing regions in geojson for {signal_type}: {set(access_regions) - set(maps_regions)}" + assert ( + set(maps_regions) - set(access_regions) == set() + ), f"Extra regions in geojson for {signal_type}: {set(maps_regions) - set(access_regions)}" + class TestWattTimeMarginalFuelMix(unittest.TestCase): def setUp(self): diff --git a/watttime/api.py b/watttime/api.py index b84b575e..db70522b 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -2,12 +2,16 @@ import time import threading import time +import logging +from collections import defaultdict from datetime import date, datetime, timedelta, time as dt_time from collections import defaultdict from functools import cache from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union from concurrent.futures import ThreadPoolExecutor, as_completed +from urllib3.util.retry import Retry +from requests.adapters import HTTPAdapter import pandas as pd import requests @@ -15,6 +19,40 @@ from pytz import UTC +class WattTimeAPIWarning: + def __init__(self, url: str, params: Dict[str, Any], warning_message: str): + self.url = url + self.params = params + self.warning_message = warning_message + + def __repr__(self): + return f"\n" + + def to_dict(self) -> Dict[str, Any]: + def stringify(value: Any) -> Any: + if isinstance(value, datetime): + return value.isoformat() + return value + + return { + "url": self.url, + "params": {k: stringify(v) for k, v in self.params.items()}, + "warning_message": self.warning_message, + } + + +def get_log(): + logging.basicConfig( + format="%(asctime)s [%(levelname)-1s] " "%(message)s", + level=logging.INFO, + handlers=[logging.StreamHandler()], + ) + return logging.getLogger() + + +LOG = get_log() + + class WattTimeBase: url_base = os.getenv("WATTTIME_API_URL", "https://api.watttime.org") @@ -37,8 +75,17 @@ def __init__( worker_count (int): The number of worker threads to use for multithreading. Default is min(10, (os.cpu_count() or 1) * 2). """ - self.username = username or os.getenv("WATTTIME_USER") - self.password = password or os.getenv("WATTTIME_PASSWORD") + + # This only applies to the current session, is not stored persistently + if username and not os.getenv("WATTTIME_USER"): + os.environ["WATTTIME_USER"] = username + if password and not os.getenv("WATTTIME_PASSWORD"): + os.environ["WATTTIME_PASSWORD"] = password + + # Accessing attributes will raise exception if variables are not set + _ = self.password + _ = self.username + self.token = None self.headers = None self.token_valid_until = None @@ -47,6 +94,7 @@ def __init__( self.rate_limit = rate_limit self._last_request_times = [] self.worker_count = worker_count + self.raised_warnings: List[WattTimeAPIWarning] = [] if self.multithreaded: self._rate_limit_lock = ( @@ -54,7 +102,39 @@ def __init__( ) # prevent multiple threads from modifying _last_request_times simultaneously self._rate_limit_condition = threading.Condition(self._rate_limit_lock) + retry_strategy = Retry( + total=3, + status_forcelist=[500, 502, 503, 504], + backoff_factor=1, + raise_on_status=False, + ) + + adapter = HTTPAdapter(max_retries=retry_strategy) self.session = requests.Session() + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + + @property + def password(self): + password = os.getenv("WATTTIME_PASSWORD") + if not password: + raise ValueError( + "WATTTIME_PASSWORD env variable is not set." + + "Please set this variable, or pass in a password upon initialization," + + "which will store it as a variable only for the current session" + ) + return password + + @property + def username(self): + username = os.getenv("WATTTIME_USER") + if not username: + raise ValueError( + "WATTTIME_USER env variable is not set." + + "Please set this variable, or pass in a username upon initialization," + + "which will store it as a variable only for the current session" + ) + return username def _login(self): """ @@ -158,7 +238,7 @@ def register(self, email: str, organization: Optional[str] = None) -> None: rsp = self.session.post(url, json=params, timeout=(10, 60)) rsp.raise_for_status() - print( + LOG.info( f"Successfully registered {self.username}, please check {email} for a verification email" ) @@ -222,10 +302,19 @@ def _make_rate_limited_request(self, url: str, params: Dict[str, Any]) -> Dict: f"API Request Failed: {e}\nURL: {url}\nParams: {params}" ) from e - if j.get("meta", {}).get("warnings"): - print("Warnings Returned: %s | Response: %s", params, j["meta"]) + meta = j.get("meta", {}) + warnings = meta.get("warnings") + if warnings: + for warning_message in warnings: + warning = WattTimeAPIWarning( + url=url, params=params, warning_message=warning_message + ) + self.raised_warnings.append(warning) + LOG.warning( + f"API Warning: {warning_message} | URL: {url} | Params: {params}" + ) - self._last_request_meta = j.get("meta", {}) + self._last_request_meta = meta return j @@ -409,7 +498,7 @@ def get_historical_csv( start, end = self._parse_dates(start, end) fp = out_dir / f"{region}_{signal_type}_{start.date()}_{end.date()}.csv" df.to_csv(fp, index=False) - print(f"file written to {fp}") + LOG.info(f"file written to {fp}") class WattTimeMyAccess(WattTimeBase): @@ -479,13 +568,16 @@ def _parse_historical_forecast_json( Returns: pd.DataFrame: A pandas DataFrame containing the parsed historical forecast data. """ - out = pd.DataFrame() - for json in json_list: - for entry in json.get("data", []): - _df = pd.json_normalize(entry, record_path=["forecast"]) - _df = _df.assign(generated_at=pd.to_datetime(entry["generated_at"])) - out = pd.concat([out, _df], ignore_index=True) - return out + data = [] + for j in json_list: + for gen_at in j["data"]: + for point_time in gen_at["forecast"]: + point_time["generated_at"] = gen_at["generated_at"] + data.append(point_time) + df = pd.DataFrame.from_records(data) + df["point_time"] = pd.to_datetime(df["point_time"]) + df["generated_at"] = pd.to_datetime(df["generated_at"]) + return df def get_forecast_json( self, @@ -709,7 +801,6 @@ def get_maps_json( class WattTimeMarginalFuelMix(WattTimeBase): - def get_fuel_mix_jsons( self, start: Union[str, datetime], @@ -731,7 +822,7 @@ def get_fuel_mix_jsons( chunks = self._get_chunks(start, end, chunk_size=timedelta(days=30)) # No model will default to the most recent model version available - if model: + if model is not None: params["model"] = model param_chunks = [{**params, "start": c[0], "end": c[1]} for c in chunks]