From 39f58d125f339ce85794784e83521b1e077fa64d Mon Sep 17 00:00:00 2001 From: MrCreosote Date: Sun, 5 Oct 2025 16:37:52 -0700 Subject: [PATCH 1/2] Add a get_token method Adds a LRU cache for tokens --- src/kbase/auth/_async/client.py | 74 +++++++++++++-- src/kbase/auth/_sync/client.py | 74 +++++++++++++-- test/test_client.py | 160 +++++++++++++++++++++++++++++++- 3 files changed, 291 insertions(+), 17 deletions(-) diff --git a/src/kbase/auth/_async/client.py b/src/kbase/auth/_async/client.py index cc3b096..172d0f0 100644 --- a/src/kbase/auth/_async/client.py +++ b/src/kbase/auth/_async/client.py @@ -7,15 +7,37 @@ # directly to the sync version - they will be overwritten. See the README for how to generate # the sync client. +from cacheout.lru import LRUCache +from dataclasses import dataclass, fields import httpx import logging -from typing import Self +import time +from typing import Self, Callable +from uuid import UUID from kbase.auth.exceptions import InvalidTokenError, InvalidUserError # TODO PUBLISH make a pypi kbase org and publish there +@dataclass +class Token: + """ A KBase authentication token. """ + id: UUID + """ The token's unique ID. """ + user: str + """ The username of the user associated with the token. """ + created: int + """ The time the token was created in epoch milliseconds. """ + expires: int + """ The time the token expires in epoch milliseconds. """ + cachefor: int + """ The time the token should be cached for in milliseconds. """ + # TODO MFA add mfa info when the auth service supports it + +_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)} + + def _require_string(putative: str, name: str) -> str: if not isinstance(putative, str) or not putative.strip(): raise ValueError(f"{name} is required and cannot be a whitespace only string") @@ -52,12 +74,21 @@ class AsyncClient: """ @classmethod - async def create(cls, base_url: str) -> Self: + async def create( + cls, + base_url: str, + cache_max_size: int = 10000, + timer: Callable[[[]], int | float] = time.time + ) -> Self: """ - Create the client from the base url for the authentication service, for example - https://kbase.us/services/auth + Create the client. + + base_url - the base url for the authentication service, for example + https://kbase.us/services/auth + cache_max_size - the maximum size of the token and user caches. + timer - the timer for the cache. Used for testing. Time unit must be seconds. """ - cli = cls(base_url) + cli = cls(base_url, cache_max_size, timer) try: res = await cli._get(cli._base_url) if res.get("servicename") != "Authentication Service": @@ -66,15 +97,22 @@ async def create(cls, base_url: str) -> Self: await cli.close() raise # TODO CLIENT look through the myriad of auth clients to see what functionality we need - # TODO CLIENT cache token & user using cachefor value from token + # TODO CLIENT cache user using cachefor value from token # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all reads only return cli - def __init__(self, base_url: str): + def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int | float]): if not _require_string(base_url, "base_url").endswith("/"): base_url += "/" self._base_url = base_url + self._token_url = base_url + "api/V2/token" + self._me_url = base_url + "api/V2/me" + if cache_max_size < 1: + raise ValueError("cache_max_size must be > 0") + if not timer: + raise ValueError("timer is required") + self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer) self._cli = httpx.AsyncClient() async def __aenter__(self): @@ -96,3 +134,25 @@ async def _get(self, url: str, headers=None): async def service_version(self) -> str: """ Return the version of the auth server. """ return (await self._get(self._base_url))["version"] + + async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token: + """ + Get information about a KBase authentication token. This method caches the token; + further caching is unnecessary in most cases. + + token - the token to query. + on_cache_miss - a function to call if a cache miss occurs. + """ + _require_string(token, "token") + tk = self._token_cache.get(token, default=False) + if tk: + return tk + if on_cache_miss: + on_cache_miss() + res = await self._get(self._token_url, headers={"Authorization": token}) + tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS}) + # TODO TEST later may want to add tests that change the cachefor value. + # Cleanest way to do this is update the auth2 service to allow setting it + # in test mode + self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) + return tk diff --git a/src/kbase/auth/_sync/client.py b/src/kbase/auth/_sync/client.py index 8fc540c..e5e7af2 100644 --- a/src/kbase/auth/_sync/client.py +++ b/src/kbase/auth/_sync/client.py @@ -7,15 +7,37 @@ # directly to the sync version - they will be overwritten. See the README for how to generate # the sync client. +from cacheout.lru import LRUCache +from dataclasses import dataclass, fields import httpx import logging -from typing import Self +import time +from typing import Self, Callable +from uuid import UUID from kbase.auth.exceptions import InvalidTokenError, InvalidUserError # TODO PUBLISH make a pypi kbase org and publish there +@dataclass +class Token: + """ A KBase authentication token. """ + id: UUID + """ The token's unique ID. """ + user: str + """ The username of the user associated with the token. """ + created: int + """ The time the token was created in epoch milliseconds. """ + expires: int + """ The time the token expires in epoch milliseconds. """ + cachefor: int + """ The time the token should be cached for in milliseconds. """ + # TODO MFA add mfa info when the auth service supports it + +_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)} + + def _require_string(putative: str, name: str) -> str: if not isinstance(putative, str) or not putative.strip(): raise ValueError(f"{name} is required and cannot be a whitespace only string") @@ -52,12 +74,21 @@ class Client: """ @classmethod - def create(cls, base_url: str) -> Self: + def create( + cls, + base_url: str, + cache_max_size: int = 10000, + timer: Callable[[[]], int | float] = time.time + ) -> Self: """ - Create the client from the base url for the authentication service, for example - https://kbase.us/services/auth + Create the client. + + base_url - the base url for the authentication service, for example + https://kbase.us/services/auth + cache_max_size - the maximum size of the token and user caches. + timer - the timer for the cache. Used for testing. Time unit must be seconds. """ - cli = cls(base_url) + cli = cls(base_url, cache_max_size, timer) try: res = cli._get(cli._base_url) if res.get("servicename") != "Authentication Service": @@ -66,15 +97,22 @@ def create(cls, base_url: str) -> Self: cli.close() raise # TODO CLIENT look through the myriad of auth clients to see what functionality we need - # TODO CLIENT cache token & user using cachefor value from token + # TODO CLIENT cache user using cachefor value from token # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all reads only return cli - def __init__(self, base_url: str): + def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int | float]): if not _require_string(base_url, "base_url").endswith("/"): base_url += "/" self._base_url = base_url + self._token_url = base_url + "api/V2/token" + self._me_url = base_url + "api/V2/me" + if cache_max_size < 1: + raise ValueError("cache_max_size must be > 0") + if not timer: + raise ValueError("timer is required") + self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer) self._cli = httpx.Client() def __enter__(self): @@ -96,3 +134,25 @@ def _get(self, url: str, headers=None): def service_version(self) -> str: """ Return the version of the auth server. """ return (self._get(self._base_url))["version"] + + def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token: + """ + Get information about a KBase authentication token. This method caches the token; + further caching is unnecessary in most cases. + + token - the token to query. + on_cache_miss - a function to call if a cache miss occurs. + """ + _require_string(token, "token") + tk = self._token_cache.get(token, default=False) + if tk: + return tk + if on_cache_miss: + on_cache_miss() + res = self._get(self._token_url, headers={"Authorization": token}) + tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS}) + # TODO TEST later may want to add tests that change the cachefor value. + # Cleanest way to do this is update the auth2 service to allow setting it + # in test mode + self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) + return tk diff --git a/test/test_client.py b/test/test_client.py index 9edb88c..e7f9d35 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1,15 +1,19 @@ import pytest +import time +from unittest.mock import Mock +import uuid from conftest import AUTH_URL, AUTH_VERSION from kbase.auth.client import KBaseAuthClient, AsyncKBaseAuthClient +from kbase.auth.exceptions import InvalidTokenError -async def _create_fail(url: str, expected: Exception): +async def _create_fail(url: str, expected: Exception, cachesize=1, timer=time.time): with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): - KBaseAuthClient.create(url) + KBaseAuthClient.create(url, cache_max_size=cachesize, timer=timer) with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): - await AsyncKBaseAuthClient.create(url) + await AsyncKBaseAuthClient.create(url, cache_max_size=cachesize, timer=timer) @pytest.mark.asyncio @@ -17,6 +21,11 @@ async def test_create_fail(): err = "base_url is required and cannot be a whitespace only string" for u in [None, " \t ", 3]: await _create_fail(u, ValueError(err)) + err = "cache_max_size must be > 0" + for t in [0, -1, -1000000]: + await _create_fail("https://ci.kbase.us/service/auth", ValueError(err), cachesize=t) + err = "timer is required" + await _create_fail("https://ci.kbase.us/service/auth", ValueError(err), timer=None) @pytest.mark.asyncio @@ -59,3 +68,148 @@ async def test_service_version_with_context_manager(auth_users): assert cli.service_version() == AUTH_VERSION async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: assert await cli.service_version() == AUTH_VERSION + + +def is_valid_uuid(u): + try: + uuid.UUID(u) + return True + except ValueError: + return False + + +def time_close_to_now(epoch_ms: int, tolerance_sec: float) -> bool: + now_ms = int(time.time() * 1000) + return abs(now_ms - epoch_ms) <= tolerance_sec * 1000 + + +@pytest.mark.asyncio +async def test_get_token_basic(auth_users): + with KBaseAuthClient.create(AUTH_URL) as cli: + t1 = cli.get_token(auth_users["user"]) + async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: + t2 = await cli.get_token(auth_users["user_random1"]) + + assert is_valid_uuid(t1.id) + assert t1.user == "user" + assert t1.cachefor == 300000 + assert time_close_to_now(t1.created, 10) + assert t1.expires - t1.created == 3600000 + + assert is_valid_uuid(t2.id) + assert t2.user == "user_random1" + assert t2.cachefor == 300000 + assert time_close_to_now(t2.created, 10) + assert t2.expires - t2.created == 3600000 + + +@pytest.mark.asyncio +async def test_get_token_basic_fail(auth_users): + err = "token is required and cannot be a whitespace only string" + await _get_token_basic_fail(None, ValueError(err)) + await _get_token_basic_fail(" \t ", ValueError(err)) + err = "KBase auth server reported token is invalid." + await _get_token_basic_fail("superfake", InvalidTokenError(err)) + + +async def _get_token_basic_fail(token: str, expected: Exception): + with KBaseAuthClient.create(AUTH_URL) as cli: + with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): + cli.get_token(token) + async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: + with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): + await cli.get_token(token) + + +@pytest.mark.asyncio +async def test_get_token_cache_evict_on_size(auth_users): + with KBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli: + cachemiss = Mock() + # fill the cache + t1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + t2 = cli.get_token(auth_users["user_random1"], on_cache_miss=cachemiss) + t3 = cli.get_token(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + # check tokens in cache + tt1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + tt2 = cli.get_token(auth_users["user_random1"], on_cache_miss=cachemiss) + tt3 = cli.get_token(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + assert tt1 == t1 + assert tt2 == t2 + assert tt3 == t3 + # Force an eviction + cli.get_token(auth_users["user_all"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 4 + # Check user was evicted + ttt1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 5 + assert ttt1 == t1 + + async with await AsyncKBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli: + cachemiss = Mock() + # fill the cache + t1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + t2 = await cli.get_token(auth_users["user_random1"], on_cache_miss=cachemiss) + t3 = await cli.get_token(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + # check tokens in cache + tt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + tt2 = await cli.get_token(auth_users["user_random1"], on_cache_miss=cachemiss) + tt3 = await cli.get_token(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + assert tt1 == t1 + assert tt2 == t2 + assert tt3 == t3 + # Force an eviction + await cli.get_token(auth_users["user_all"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 4 + # Check user was evicted + ttt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 5 + assert ttt1 == t1 + + +# easier to understand than a mock with an array of times +class FakeTimer: + def __init__(self): + self.current = 1000 # arbitrary start time + def __call__(self) -> float: + return self.current + def advance(self, seconds: float): + self.current += seconds + + +@pytest.mark.asyncio +async def test_get_token_cache_evict_on_time(auth_users): + timer = FakeTimer() + with KBaseAuthClient.create(AUTH_URL, timer=timer) as cli: + cachemiss = Mock() + t1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + # TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow + # setting different values and test here + timer.advance(299) + tt1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + assert tt1 == t1 + timer.advance(2) + ttt1 = cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 2 + assert ttt1 == t1 + + timer = FakeTimer() + async with await AsyncKBaseAuthClient.create(AUTH_URL, timer=timer) as cli: + cachemiss = Mock() + t1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + # TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow + # setting different values and test here + timer.advance(299) + tt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + assert tt1 == t1 + timer.advance(2) + ttt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 2 + assert ttt1 == t1 From fa5f80e06335e126156fc84c90bc4125565941d2 Mon Sep 17 00:00:00 2001 From: MrCreosote Date: Thu, 9 Oct 2025 14:09:04 -0700 Subject: [PATCH 2/2] Add docs re cache eviction --- src/kbase/auth/_async/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/kbase/auth/_async/client.py b/src/kbase/auth/_async/client.py index 172d0f0..a5af7ea 100644 --- a/src/kbase/auth/_async/client.py +++ b/src/kbase/auth/_async/client.py @@ -85,7 +85,8 @@ async def create( base_url - the base url for the authentication service, for example https://kbase.us/services/auth - cache_max_size - the maximum size of the token and user caches. + cache_max_size - the maximum size of the token and user caches. When the cache size is + exceeded, the least recently used entries are evicted from the cache. timer - the timer for the cache. Used for testing. Time unit must be seconds. """ cli = cls(base_url, cache_max_size, timer)