Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions src/kbase/auth/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,37 @@
# directly to the sync version - they will be overwritten. See the README for how to generate
# the sync client.

from cacheout.lru import LRUCache
from dataclasses import dataclass, fields
import httpx
import logging
from typing import Self
import time
from typing import Self, Callable
from uuid import UUID

from kbase.auth.exceptions import InvalidTokenError, InvalidUserError

# TODO PUBLISH make a pypi kbase org and publish there


@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
created: int
""" The time the token was created in epoch milliseconds. """
expires: int
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it

_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -52,12 +74,22 @@ class AsyncClient:
"""

@classmethod
async def create(cls, base_url: str) -> Self:
async def create(
cls,
base_url: str,
cache_max_size: int = 10000,
timer: Callable[[[]], int | float] = time.time
) -> Self:
"""
Create the client from the base url for the authentication service, for example
https://kbase.us/services/auth
Create the client.

base_url - the base url for the authentication service, for example
https://kbase.us/services/auth
cache_max_size - the maximum size of the token and user caches. When the cache size is
exceeded, the least recently used entries are evicted from the cache.
timer - the timer for the cache. Used for testing. Time unit must be seconds.
"""
cli = cls(base_url)
cli = cls(base_url, cache_max_size, timer)
try:
res = await cli._get(cli._base_url)
if res.get("servicename") != "Authentication Service":
Expand All @@ -66,15 +98,22 @@ async def create(cls, base_url: str) -> Self:
await cli.close()
raise
# TODO CLIENT look through the myriad of auth clients to see what functionality we need
# TODO CLIENT cache token & user using cachefor value from token
# TODO CLIENT cache user using cachefor value from token
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all reads only
return cli

def __init__(self, base_url: str):
def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int | float]):
if not _require_string(base_url, "base_url").endswith("/"):
base_url += "/"
self._base_url = base_url
self._token_url = base_url + "api/V2/token"
self._me_url = base_url + "api/V2/me"
if cache_max_size < 1:
raise ValueError("cache_max_size must be > 0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, cacheout.LRUCache is fine with 0. That just removes them when they age out.
But it really doesn't matter, IMO.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I don't want to allow infinite cache size

if not timer:
raise ValueError("timer is required")
self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._cli = httpx.AsyncClient()

async def __aenter__(self):
Expand All @@ -96,3 +135,25 @@ async def _get(self, url: str, headers=None):
async def service_version(self) -> str:
""" Return the version of the auth server. """
return (await self._get(self._base_url))["version"]

async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token:
"""
Get information about a KBase authentication token. This method caches the token;
further caching is unnecessary in most cases.

token - the token to query.
on_cache_miss - a function to call if a cache miss occurs.
"""
_require_string(token, "token")
tk = self._token_cache.get(token, default=False)
if tk:
return tk
if on_cache_miss:
on_cache_miss()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the use case for on_cache_miss? Show a warning or something? Or testing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now just testing, but could be used for stats or something

res = await self._get(self._token_url, headers={"Authorization": token})
tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
# Cleanest way to do this is update the auth2 service to allow setting it
# in test mode
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk
74 changes: 67 additions & 7 deletions src/kbase/auth/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,37 @@
# directly to the sync version - they will be overwritten. See the README for how to generate
# the sync client.

from cacheout.lru import LRUCache
from dataclasses import dataclass, fields
import httpx
import logging
from typing import Self
import time
from typing import Self, Callable
from uuid import UUID

from kbase.auth.exceptions import InvalidTokenError, InvalidUserError

# TODO PUBLISH make a pypi kbase org and publish there


@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
created: int
""" The time the token was created in epoch milliseconds. """
expires: int
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it

_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -52,12 +74,21 @@ class Client:
"""

@classmethod
def create(cls, base_url: str) -> Self:
def create(
cls,
base_url: str,
cache_max_size: int = 10000,
timer: Callable[[[]], int | float] = time.time
) -> Self:
"""
Create the client from the base url for the authentication service, for example
https://kbase.us/services/auth
Create the client.

base_url - the base url for the authentication service, for example
https://kbase.us/services/auth
cache_max_size - the maximum size of the token and user caches.
timer - the timer for the cache. Used for testing. Time unit must be seconds.
"""
cli = cls(base_url)
cli = cls(base_url, cache_max_size, timer)
try:
res = cli._get(cli._base_url)
if res.get("servicename") != "Authentication Service":
Expand All @@ -66,15 +97,22 @@ def create(cls, base_url: str) -> Self:
cli.close()
raise
# TODO CLIENT look through the myriad of auth clients to see what functionality we need
# TODO CLIENT cache token & user using cachefor value from token
# TODO CLIENT cache user using cachefor value from token
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all reads only
return cli

def __init__(self, base_url: str):
def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int | float]):
if not _require_string(base_url, "base_url").endswith("/"):
base_url += "/"
self._base_url = base_url
self._token_url = base_url + "api/V2/token"
self._me_url = base_url + "api/V2/me"
if cache_max_size < 1:
raise ValueError("cache_max_size must be > 0")
if not timer:
raise ValueError("timer is required")
self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._cli = httpx.Client()

def __enter__(self):
Expand All @@ -96,3 +134,25 @@ def _get(self, url: str, headers=None):
def service_version(self) -> str:
""" Return the version of the auth server. """
return (self._get(self._base_url))["version"]

def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token:
"""
Get information about a KBase authentication token. This method caches the token;
further caching is unnecessary in most cases.

token - the token to query.
on_cache_miss - a function to call if a cache miss occurs.
"""
_require_string(token, "token")
tk = self._token_cache.get(token, default=False)
if tk:
return tk
if on_cache_miss:
on_cache_miss()
res = self._get(self._token_url, headers={"Authorization": token})
tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
# Cleanest way to do this is update the auth2 service to allow setting it
# in test mode
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk
Loading