From 2c53c6495555922e2994331bf273b44691ab868f Mon Sep 17 00:00:00 2001 From: MrCreosote Date: Thu, 22 Jan 2026 11:55:52 -0800 Subject: [PATCH] Add token type field to returned Token class --- README.md | 2 +- RELEASE_NOTES.md | 1 + src/kbase/_auth/_async/client.py | 10 +++++++- src/kbase/_auth/_sync/client.py | 10 +++++++- src/kbase/_auth/models.py | 40 +++++++++++++++++++++++++++++++- test/conftest.py | 11 ++++++--- test/test_client.py | 24 +++++++++++++++++-- 7 files changed, 89 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 15e81fb..2c60347 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ from kbase.auth import AsyncKBaseAuthClient async with await AsyncKBaseAuthClient.create("https://ci.kbase.us/services/auth") as cli: print(await cli.get_token(token)) -Token(id='67797406-c6a3-4ee0-870d-976739dacd61', user='gaprice', mfa=, created=1755561300704, expires=1763337300704, cachefor=300000) +Token(id='fe042c54-0eb6-4bd6-a7cc-3c1c4d0228d2', user='gaprice', type=, mfa=, created=1763674630238, expires=1771450630238, cachefor=300000) ``` ### Get a user diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index c3d2686..e9bdd7c 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,6 +1,7 @@ ## 0.1.2 * Add MFA support. MFA will always be `UNKNOWN` if auth2 is not version 0.8.0 or above. +* Added the token type to the returned Token class. ## 0.1.1 diff --git a/src/kbase/_auth/_async/client.py b/src/kbase/_auth/_async/client.py index 50811c7..8415490 100644 --- a/src/kbase/_auth/_async/client.py +++ b/src/kbase/_auth/_async/client.py @@ -14,7 +14,14 @@ from typing import Self, Callable from kbase._auth.exceptions import InvalidTokenError, InvalidUserError -from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus +from kbase._auth.models import ( + MFAStatus, + Token, + TokenType, + User, + VALID_TOKEN_FIELDS, + VALID_USER_FIELDS, +) # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all read only @@ -135,6 +142,7 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> res = await self._get(self._token_url, headers={"Authorization": token}) targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS} targs["mfa"] = MFAStatus.get_mfa(res.get("mfa")) + targs["type"] = TokenType.get_type(res["type"]) tk = Token(**targs) # TODO TEST later may want to add tests that change the cachefor value. self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) diff --git a/src/kbase/_auth/_sync/client.py b/src/kbase/_auth/_sync/client.py index 8cd8d01..d26c750 100644 --- a/src/kbase/_auth/_sync/client.py +++ b/src/kbase/_auth/_sync/client.py @@ -14,7 +14,14 @@ from typing import Self, Callable from kbase._auth.exceptions import InvalidTokenError, InvalidUserError -from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus +from kbase._auth.models import ( + MFAStatus, + Token, + TokenType, + User, + VALID_TOKEN_FIELDS, + VALID_USER_FIELDS, +) # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all read only @@ -135,6 +142,7 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token res = self._get(self._token_url, headers={"Authorization": token}) targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS} targs["mfa"] = MFAStatus.get_mfa(res.get("mfa")) + targs["type"] = TokenType.get_type(res["type"]) tk = Token(**targs) # TODO TEST later may want to add tests that change the cachefor value. self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) diff --git a/src/kbase/_auth/models.py b/src/kbase/_auth/models.py index c64dc6b..3bbd985 100644 --- a/src/kbase/_auth/models.py +++ b/src/kbase/_auth/models.py @@ -44,6 +44,42 @@ def get_mfa(cls, mfa: str): } +class TokenType(Enum): + + LOGIN = 1 + """ A login token, generated by the user going through a login flow. """ + + AGENT = 2 + """ + An agent token generated from a login token. Intended for creation prior to running + a job to prevent token expiration during the job. + """ + + DEVELOPER = 3 + """ Longer lived tokens intended for developer use of the system via API calls. """ + + SERVICE = 4 + """ Very long lived tokens intended for services to communicate with each other. """ + + @classmethod + def get_type(cls, type_: str): + """ Given a string, get the token enum. """ + if not type_: + raise ValueError("type_ is required") + type_ = type_.lower() + if type_ not in _STR2TYPE: + raise ValueError("Unknown token type string: " + type_) + return _STR2TYPE[type_] + + +_STR2TYPE = { + "login": TokenType.LOGIN, + "agent": TokenType.AGENT, + "developer": TokenType.DEVELOPER, + "service": TokenType.SERVICE, +} + + @dataclass class Token: """ A KBase authentication token. """ @@ -51,6 +87,8 @@ class Token: """ The token's unique ID. """ user: str """ The username of the user associated with the token. """ + type: TokenType + """ The type of the token. """ mfa: MFAStatus """ The MFA status of the token. """ created: int @@ -59,7 +97,7 @@ class Token: """ The time the token expires in epoch milliseconds. """ cachefor: int """ The time the token should be cached for in milliseconds. """ - # TODO MFA add mfa info when the auth service supports it + VALID_TOKEN_FIELDS: set[str] = {f.name for f in fields(Token)} """ diff --git a/test/conftest.py b/test/conftest.py index 0774de2..965d9db 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -116,15 +116,20 @@ def add_roles(user: str, roles: list[str]): def auth_users(set_up_auth_roles) -> dict[str, str]: # username -> token ret = {} users = { - "user": "Used", "user_random1": "Unknown", "user_random2": "NotUsed", "user_all": None + "user": ("Used", "Login"), + "user_random1": ("Unknown", "Agent"), + "user_random2": ("NotUsed", "Dev"), + "user_all": (None, "Serv"), } - for u, mfa in users.items(): + for u, (mfa, tt) in users.items(): res = requests.post(f"{_AUTH_API}testmodeonly/user", json={"user": u, "display": "foo"}) res.raise_for_status() - reqjson = {"user": u, "type": "Login"} + reqjson = {"user": u, "type": tt} if mfa: reqjson["mfa"] = mfa res = requests.post(f"{_AUTH_API}testmodeonly/token", json=reqjson) + if not res.status_code == 200: + print(res.text) res.raise_for_status() ret[u] = res.json()["token"] diff --git a/test/test_client.py b/test/test_client.py index d55bcfb..9fd4adf 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -14,7 +14,7 @@ User, __version__ as ver, ) -from kbase._auth.models import MFAStatus +from kbase._auth.models import MFAStatus, TokenType def test_version(): @@ -34,6 +34,22 @@ def test_mfastatus_get_mfa_fail(): MFAStatus.get_mfa(mfa) +def test_tokentype_get_tokentype(): + assert TokenType.get_type("Login") == TokenType.LOGIN + assert TokenType.get_type("agent") == TokenType.AGENT + assert TokenType.get_type("deVelopEr") == TokenType.DEVELOPER + assert TokenType.get_type("Service") == TokenType.SERVICE + + +def test_tokentype_get_tokentype_fail(): + with pytest.raises(ValueError, match="type_ is required"): + TokenType.get_type(None) + + for t in ["foo", "loginated", "dunno"]: + with pytest.raises(ValueError, match=f"Unknown token type string: {t}"): + TokenType.get_type(t) + + async def _create_fail(url: str, expected: Exception, cachesize=1, timer=time.time): with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): KBaseAuthClient.create(url, cache_max_size=cachesize, timer=timer) @@ -120,6 +136,7 @@ async def test_get_token_basic(auth_users): assert t1 == Token( id=t1.id, user="user", + type=TokenType.LOGIN, cachefor=300000, created=t1.created, expires=t1.expires, @@ -132,6 +149,7 @@ async def test_get_token_basic(auth_users): assert t2 == Token( id=t2.id, user="user_random1", + type=TokenType.AGENT, cachefor=300000, created=t2.created, expires=t2.expires, @@ -141,9 +159,11 @@ async def test_get_token_basic(auth_users): assert time_close_to_now(t2.created, 10) assert t2.expires - t2.created == 3600000 - # for the remaining tokens we just check mfa + # for the remaining tokens we just check mfa & type assert t3.mfa == MFAStatus.NOT_USED + assert t3.type == TokenType.DEVELOPER assert t4.mfa == MFAStatus.UNKNOWN + assert t4.type == TokenType.SERVICE @pytest.mark.asyncio