From b99bff5ade28516a3cfd2920fc9ad5ca7c077d6e Mon Sep 17 00:00:00 2001 From: MrCreosote Date: Fri, 21 Nov 2025 17:35:07 -0800 Subject: [PATCH] Add MFA support --- README.md | 2 +- RELEASE_NOTES.md | 4 ++++ docker-compose.yaml | 2 +- pyproject.toml | 2 +- src/kbase/_auth/_async/client.py | 6 +++-- src/kbase/_auth/_sync/client.py | 6 +++-- src/kbase/_auth/models.py | 41 ++++++++++++++++++++++++++++++++ src/kbase/auth.py | 2 +- test/conftest.py | 14 ++++++++--- test/test_client.py | 38 +++++++++++++++++++++++++---- uv.lock | 2 +- 11 files changed, 103 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 46b2f99..15e81fb 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ from kbase.auth import AsyncKBaseAuthClient async with await AsyncKBaseAuthClient.create("https://ci.kbase.us/services/auth") as cli: print(await cli.get_token(token)) -Token(id='67797406-c6a3-4ee0-870d-976739dacd61', user='gaprice', created=1755561300704, expires=1763337300704, cachefor=300000) +Token(id='67797406-c6a3-4ee0-870d-976739dacd61', user='gaprice', mfa=, created=1755561300704, expires=1763337300704, cachefor=300000) ``` ### Get a user diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 48bdd47..c3d2686 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,3 +1,7 @@ +## 0.1.2 + +* Add MFA support. MFA will always be `UNKNOWN` if auth2 is not version 0.8.0 or above. + ## 0.1.1 * Update README with install instructions diff --git a/docker-compose.yaml b/docker-compose.yaml index 4d224ea..6dc1aef 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,7 +1,7 @@ services: auth: - image: ghcr.io/kbase/auth2:0.7.1 + image: ghcr.io/kbase/auth2:0.8.0 platform: linux/amd64 ports: - 50001:8080 diff --git a/pyproject.toml b/pyproject.toml index 3491006..685bd54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kbase-auth-client" -version = "0.1.1" +version = "0.1.2" description = "Client for the KBase Authentication Service" readme = "README.md" authors = [{ name = "KBase Development Team" }] diff --git a/src/kbase/_auth/_async/client.py b/src/kbase/_auth/_async/client.py index 63d5b93..50811c7 100644 --- a/src/kbase/_auth/_async/client.py +++ b/src/kbase/_auth/_async/client.py @@ -14,7 +14,7 @@ from typing import Self, Callable from kbase._auth.exceptions import InvalidTokenError, InvalidUserError -from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS +from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all read only @@ -133,7 +133,9 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> if on_cache_miss: on_cache_miss() res = await self._get(self._token_url, headers={"Authorization": token}) - tk = Token(**{k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS}) + targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS} + targs["mfa"] = MFAStatus.get_mfa(res.get("mfa")) + tk = Token(**targs) # TODO TEST later may want to add tests that change the cachefor value. self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) return tk diff --git a/src/kbase/_auth/_sync/client.py b/src/kbase/_auth/_sync/client.py index 8c27868..8cd8d01 100644 --- a/src/kbase/_auth/_sync/client.py +++ b/src/kbase/_auth/_sync/client.py @@ -14,7 +14,7 @@ from typing import Self, Callable from kbase._auth.exceptions import InvalidTokenError, InvalidUserError -from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS +from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all read only @@ -133,7 +133,9 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token if on_cache_miss: on_cache_miss() res = self._get(self._token_url, headers={"Authorization": token}) - tk = Token(**{k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS}) + targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS} + targs["mfa"] = MFAStatus.get_mfa(res.get("mfa")) + tk = Token(**targs) # TODO TEST later may want to add tests that change the cachefor value. self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) return tk diff --git a/src/kbase/_auth/models.py b/src/kbase/_auth/models.py index 8e45c89..c64dc6b 100644 --- a/src/kbase/_auth/models.py +++ b/src/kbase/_auth/models.py @@ -3,8 +3,47 @@ """ from dataclasses import dataclass, fields +from enum import Enum from uuid import UUID + +class MFAStatus(Enum): + + USED = 1 + """ The user used MFA when logging in. """ + + NOT_USED = 2 + """ The user chose not to use MFA when logging in. """ + + UNKNOWN = 3 + """ + Either + * The 3rd party identity supplier does not support MFA or + * The 3rd party identity supplier was configured not to use MFA or + * The 3rd party identity supplier did not provide enough information to determine if + MFA was used or + * MFA is not applicable to the data (e.g. token types other than Login tokens). + + """ + + @classmethod + def get_mfa(cls, mfa: str): + """ Given a string, get the mfa enum. """ + if not mfa: + return cls.UNKNOWN + mfa = mfa.lower() + if mfa not in _STR2MFA: + raise ValueError("Unknown MFA string: " + mfa) + return _STR2MFA[mfa] + + +_STR2MFA = { + "used": MFAStatus.USED, + "notused": MFAStatus.NOT_USED, + "unknown": MFAStatus.UNKNOWN, +} + + @dataclass class Token: """ A KBase authentication token. """ @@ -12,6 +51,8 @@ class Token: """ The token's unique ID. """ user: str """ The username of the user associated with the token. """ + mfa: MFAStatus + """ The MFA status of the token. """ created: int """ The time the token was created in epoch milliseconds. """ expires: int diff --git a/src/kbase/auth.py b/src/kbase/auth.py index 51ab728..c03e981 100644 --- a/src/kbase/auth.py +++ b/src/kbase/auth.py @@ -15,4 +15,4 @@ ) -__version__ = "0.1.1" +__version__ = "0.1.2" diff --git a/test/conftest.py b/test/conftest.py index ee1e8ac..0774de2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -18,7 +18,7 @@ AUTH_URL = "http://localhost:50001/testmode" _AUTH_API = AUTH_URL + "/api/V2/" -AUTH_VERSION = "0.7.1" +AUTH_VERSION = "0.8.0" SOME_RANDOM_ROLE1 = "random1" SOME_RANDOM_ROLE2 = "random2" @@ -107,16 +107,24 @@ def add_roles(user: str, roles: list[str]): res = requests.put( f"{_AUTH_API}testmodeonly/userroles", json={"user": user, "customroles": roles}, ) + if not res.status_code == 200: + print(res.text) res.raise_for_status() @pytest.fixture(scope="session", autouse=True) def auth_users(set_up_auth_roles) -> dict[str, str]: # username -> token ret = {} - for u in ["user", "user_random1", "user_random2", "user_all"]: + users = { + "user": "Used", "user_random1": "Unknown", "user_random2": "NotUsed", "user_all": None + } + for u, mfa in users.items(): res = requests.post(f"{_AUTH_API}testmodeonly/user", json={"user": u, "display": "foo"}) res.raise_for_status() - res = requests.post(f"{_AUTH_API}testmodeonly/token", json={"user": u, "type": "Dev"}) + reqjson = {"user": u, "type": "Login"} + if mfa: + reqjson["mfa"] = mfa + res = requests.post(f"{_AUTH_API}testmodeonly/token", json=reqjson) res.raise_for_status() ret[u] = res.json()["token"] diff --git a/test/test_client.py b/test/test_client.py index 33cf6db..d55bcfb 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -14,10 +14,24 @@ User, __version__ as ver, ) +from kbase._auth.models import MFAStatus def test_version(): - assert ver == "0.1.1" + assert ver == "0.1.2" + + +def test_mfastatus_get_mfa(): + assert MFAStatus.get_mfa(None) == MFAStatus.UNKNOWN + assert MFAStatus.get_mfa("Used") == MFAStatus.USED + assert MFAStatus.get_mfa("UnKnoWn") == MFAStatus.UNKNOWN + assert MFAStatus.get_mfa("notused") == MFAStatus.NOT_USED + + +def test_mfastatus_get_mfa_fail(): + for mfa in ["foo", "useded", "dunno"]: + with pytest.raises(ValueError, match=f"Unknown MFA string: {mfa}"): + MFAStatus.get_mfa(mfa) async def _create_fail(url: str, expected: Exception, cachesize=1, timer=time.time): @@ -100,20 +114,36 @@ async def test_get_token_basic(auth_users): t1 = cli.get_token(auth_users["user"]) async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: t2 = await cli.get_token(auth_users["user_random1"]) - + t3 = await cli.get_token(auth_users["user_random2"]) + t4 = await cli.get_token(auth_users["user_all"]) + assert t1 == Token( - id=t1.id, user="user", cachefor=300000, created=t1.created, expires=t1.expires + id=t1.id, + user="user", + cachefor=300000, + created=t1.created, + expires=t1.expires, + mfa=MFAStatus.USED, ) assert is_valid_uuid(t1.id) assert time_close_to_now(t1.created, 10) assert t1.expires - t1.created == 3600000 assert t2 == Token( - id=t2.id, user="user_random1", cachefor=300000, created=t2.created, expires=t2.expires + id=t2.id, + user="user_random1", + cachefor=300000, + created=t2.created, + expires=t2.expires, + mfa=MFAStatus.UNKNOWN, ) assert is_valid_uuid(t2.id) assert time_close_to_now(t2.created, 10) assert t2.expires - t2.created == 3600000 + + # for the remaining tokens we just check mfa + assert t3.mfa == MFAStatus.NOT_USED + assert t4.mfa == MFAStatus.UNKNOWN @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index ab7b5ba..a33c62b 100644 --- a/uv.lock +++ b/uv.lock @@ -297,7 +297,7 @@ wheels = [ [[package]] name = "kbase-auth-client" -version = "0.1.1" +version = "0.1.2" source = { editable = "." } dependencies = [ { name = "cacheout" },