Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions scripts/process_unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def main():

rules = [
unasync.Rule(
fromdir="/src/kbase/auth/_async/",
todir="/src/kbase/auth/_sync/",
fromdir="/src/kbase/_auth/_async/",
todir="/src/kbase/_auth/_sync/",
additional_replacements=additional_replacements,
),
]

filepaths = [
str(Path(__file__).parent.parent / "src" / "kbase" / "auth" / "_async" / "client.py")
str(Path(__file__).parent.parent / "src" / "kbase" / "_auth" / "_async" / "client.py")
]

unasync.unasync_files(filepaths, rules)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,20 @@
# the sync client.

from cacheout.lru import LRUCache
from dataclasses import dataclass, fields
import httpx
import logging
import time
from typing import Self, Callable
from uuid import UUID

from kbase.auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS

# TODO PUBLISH make a pypi kbase org and publish there
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all read only
# TODO NOW CODE make a kbase/auth.py module, move other code into _auth, and import everything
# TODO NOW CODE move Token and User into a common class
# We might want to expand exceptions to include the request ID for debugging purposes


@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
created: int
""" The time the token was created in epoch milliseconds. """
expires: int
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it

_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


@dataclass
class User:
""" Information about a KBase user. """
user: str
""" The username of the user associated with the token. """
customroles: list[str]
""" The Auth2 custom roles the user possesses. """
# Not seeing any other fields that are generally useful right now
# Don't really want to expose idents unless there's a very good reason


_VALID_USER_FIELDS = {f.name for f in fields(User)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -169,7 +134,7 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) ->
if on_cache_miss:
on_cache_miss()
res = await self._get(self._token_url, headers={"Authorization": token})
tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS})
tk = Token(**{k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk
Expand All @@ -193,7 +158,7 @@ async def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) ->
on_cache_miss()
tk = await self.get_token(token)
res = await self._get(self._me_url, headers={"Authorization": token})
u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS})
u = User(**{k: v for k, v in res.items() if k in VALID_USER_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
self._user_cache.set(token, u, ttl=tk.cachefor / 1000)
return u
Expand Down
43 changes: 4 additions & 39 deletions src/kbase/auth/_sync/client.py → src/kbase/_auth/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,20 @@
# the sync client.

from cacheout.lru import LRUCache
from dataclasses import dataclass, fields
import httpx
import logging
import time
from typing import Self, Callable
from uuid import UUID

from kbase.auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS

# TODO PUBLISH make a pypi kbase org and publish there
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all read only
# TODO NOW CODE make a kbase/auth.py module, move other code into _auth, and import everything
# TODO NOW CODE move Token and User into a common class
# We might want to expand exceptions to include the request ID for debugging purposes


@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
created: int
""" The time the token was created in epoch milliseconds. """
expires: int
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it

_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


@dataclass
class User:
""" Information about a KBase user. """
user: str
""" The username of the user associated with the token. """
customroles: list[str]
""" The Auth2 custom roles the user possesses. """
# Not seeing any other fields that are generally useful right now
# Don't really want to expose idents unless there's a very good reason


_VALID_USER_FIELDS = {f.name for f in fields(User)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -169,7 +134,7 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token
if on_cache_miss:
on_cache_miss()
res = self._get(self._token_url, headers={"Authorization": token})
tk = Token(**{k: v for k, v in res.items() if k in _VALID_TOKEN_FIELDS})
tk = Token(**{k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk
Expand All @@ -193,7 +158,7 @@ def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) -> User:
on_cache_miss()
tk = self.get_token(token)
res = self._get(self._me_url, headers={"Authorization": token})
u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS})
u = User(**{k: v for k, v in res.items() if k in VALID_USER_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
self._user_cache.set(token, u, ttl=tk.cachefor / 1000)
return u
Expand Down
File renamed without changes.
43 changes: 43 additions & 0 deletions src/kbase/_auth/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Data classes for the clients.
"""

from dataclasses import dataclass, fields
from uuid import UUID

@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
created: int
""" The time the token was created in epoch milliseconds. """
expires: int
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it

VALID_TOKEN_FIELDS: set[str] = {f.name for f in fields(Token)}
"""
The field names for the Token dataclass.
"""


@dataclass
class User:
""" Information about a KBase user. """
user: str
""" The username of the user associated with the token. """
customroles: list[str]
""" The Auth2 custom roles the user possesses. """
# Not seeing any other fields that are generally useful right now
# Don't really want to expose idents unless there's a very good reason


VALID_USER_FIELDS: set[str] = {f.name for f in fields(User)}
"""
The field names for the user dataclass.
"""
15 changes: 15 additions & 0 deletions src/kbase/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
The aync and sync versions of the KBase Auth Client.
"""

from kbase._auth._async.client import AsyncKBaseAuthClient # @UnusedImport
from kbase._auth._sync.client import KBaseAuthClient # @UnusedImport
from kbase._auth.exceptions import (
AuthenticationError, # @UnusedImport
InvalidTokenError, # @UnusedImport
InvalidUserError, # @UnusedImport
)
from kbase._auth.models import (
Token, # @UnusedImport
User, # @UnusedImport
)
6 changes: 0 additions & 6 deletions src/kbase/auth/client.py

This file was deleted.

35 changes: 18 additions & 17 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@

from conftest import AUTH_URL, AUTH_VERSION

from kbase.auth.client import KBaseAuthClient, AsyncKBaseAuthClient
from kbase.auth.exceptions import InvalidTokenError, InvalidUserError
from kbase.auth import (
AsyncKBaseAuthClient,
InvalidTokenError,
InvalidUserError,
KBaseAuthClient,
Token,
User,
)


async def _create_fail(url: str, expected: Exception, cachesize=1, timer=time.time):
Expand Down Expand Up @@ -90,15 +96,17 @@ async def test_get_token_basic(auth_users):
async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli:
t2 = await cli.get_token(auth_users["user_random1"])

assert t1 == Token(
id=t1.id, user="user", cachefor=300000, created=t1.created, expires=t1.expires
)
assert is_valid_uuid(t1.id)
assert t1.user == "user"
assert t1.cachefor == 300000
assert time_close_to_now(t1.created, 10)
assert t1.expires - t1.created == 3600000

assert t2 == Token(
id=t2.id, user="user_random1", cachefor=300000, created=t2.created, expires=t2.expires
)
assert is_valid_uuid(t2.id)
assert t2.user == "user_random1"
assert t2.cachefor == 300000
assert time_close_to_now(t2.created, 10)
assert t2.expires - t2.created == 3600000

Expand Down Expand Up @@ -222,17 +230,10 @@ async def test_get_user_basic(auth_users):
u3 = await cli.get_user(auth_users["user_random1"])
u4 = await cli.get_user(auth_users["user_random2"])

assert u1.user == "user"
assert u1.customroles == []

assert u2.user == "user_all"
assert u2.customroles == ["random1", "random2"]

assert u3.user == "user_random1"
assert u3.customroles == ["random1"]

assert u4.user == "user_random2"
assert u4.customroles == ["random2"]
assert u1 == User(user="user", customroles=[])
assert u2 == User(user="user_all", customroles=["random1", "random2"])
assert u3 == User(user="user_random1", customroles=["random1"])
assert u4 == User(user="user_random2", customroles=["random2"])


@pytest.mark.asyncio
Expand Down