Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ from kbase.auth import AsyncKBaseAuthClient

async with await AsyncKBaseAuthClient.create("https://ci.kbase.us/services/auth") as cli:
print(await cli.get_token(token))
Token(id='67797406-c6a3-4ee0-870d-976739dacd61', user='gaprice', mfa=<MFAStatus.UNKNOWN: 3>, created=1755561300704, expires=1763337300704, cachefor=300000)
Token(id='fe042c54-0eb6-4bd6-a7cc-3c1c4d0228d2', user='gaprice', type=<TokenType.DEVELOPER: 3>, mfa=<MFAStatus.UNKNOWN: 3>, created=1763674630238, expires=1771450630238, cachefor=300000)
```

### Get a user
Expand Down
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.1.2

* Add MFA support. MFA will always be `UNKNOWN` if auth2 is not version 0.8.0 or above.
* Added the token type to the returned Token class.

## 0.1.1

Expand Down
10 changes: 9 additions & 1 deletion src/kbase/_auth/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
from typing import Self, Callable

from kbase._auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus
from kbase._auth.models import (
MFAStatus,
Token,
TokenType,
User,
VALID_TOKEN_FIELDS,
VALID_USER_FIELDS,
)

# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all read only
Expand Down Expand Up @@ -135,6 +142,7 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) ->
res = await self._get(self._token_url, headers={"Authorization": token})
targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS}
targs["mfa"] = MFAStatus.get_mfa(res.get("mfa"))
targs["type"] = TokenType.get_type(res["type"])
tk = Token(**targs)
# TODO TEST later may want to add tests that change the cachefor value.
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
Expand Down
10 changes: 9 additions & 1 deletion src/kbase/_auth/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
from typing import Self, Callable

from kbase._auth.exceptions import InvalidTokenError, InvalidUserError
from kbase._auth.models import Token, User, VALID_TOKEN_FIELDS, VALID_USER_FIELDS, MFAStatus
from kbase._auth.models import (
MFAStatus,
Token,
TokenType,
User,
VALID_TOKEN_FIELDS,
VALID_USER_FIELDS,
)

# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all read only
Expand Down Expand Up @@ -135,6 +142,7 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token
res = self._get(self._token_url, headers={"Authorization": token})
targs = {k: v for k, v in res.items() if k in VALID_TOKEN_FIELDS}
targs["mfa"] = MFAStatus.get_mfa(res.get("mfa"))
targs["type"] = TokenType.get_type(res["type"])
tk = Token(**targs)
# TODO TEST later may want to add tests that change the cachefor value.
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
Expand Down
40 changes: 39 additions & 1 deletion src/kbase/_auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,51 @@ def get_mfa(cls, mfa: str):
}


class TokenType(Enum):

LOGIN = 1
""" A login token, generated by the user going through a login flow. """

AGENT = 2
"""
An agent token generated from a login token. Intended for creation prior to running
a job to prevent token expiration during the job.
"""

DEVELOPER = 3
""" Longer lived tokens intended for developer use of the system via API calls. """

SERVICE = 4
""" Very long lived tokens intended for services to communicate with each other. """

@classmethod
def get_type(cls, type_: str):
""" Given a string, get the token enum. """
if not type_:
raise ValueError("type_ is required")
type_ = type_.lower()
if type_ not in _STR2TYPE:
raise ValueError("Unknown token type string: " + type_)
return _STR2TYPE[type_]


_STR2TYPE = {
"login": TokenType.LOGIN,
"agent": TokenType.AGENT,
"developer": TokenType.DEVELOPER,
"service": TokenType.SERVICE,
}


Comment on lines +75 to +82

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to be equivalent to _STR2TYPE = {t.name.lower(): t for t in TokenType}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm going to leave it as is just because the mfa enum is done the same way and the input names aren't just lower case of the MFA enum names

@dataclass
class Token:
""" A KBase authentication token. """
id: UUID
""" The token's unique ID. """
user: str
""" The username of the user associated with the token. """
type: TokenType
""" The type of the token. """
mfa: MFAStatus
""" The MFA status of the token. """
created: int
Expand All @@ -59,7 +97,7 @@ class Token:
""" The time the token expires in epoch milliseconds. """
cachefor: int
""" The time the token should be cached for in milliseconds. """
# TODO MFA add mfa info when the auth service supports it


VALID_TOKEN_FIELDS: set[str] = {f.name for f in fields(Token)}
"""
Expand Down
11 changes: 8 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,20 @@ def add_roles(user: str, roles: list[str]):
def auth_users(set_up_auth_roles) -> dict[str, str]: # username -> token
ret = {}
users = {
"user": "Used", "user_random1": "Unknown", "user_random2": "NotUsed", "user_all": None
"user": ("Used", "Login"),
"user_random1": ("Unknown", "Agent"),
"user_random2": ("NotUsed", "Dev"),
"user_all": (None, "Serv"),
}
for u, mfa in users.items():
for u, (mfa, tt) in users.items():
res = requests.post(f"{_AUTH_API}testmodeonly/user", json={"user": u, "display": "foo"})
res.raise_for_status()
reqjson = {"user": u, "type": "Login"}
reqjson = {"user": u, "type": tt}
if mfa:
reqjson["mfa"] = mfa
res = requests.post(f"{_AUTH_API}testmodeonly/token", json=reqjson)
if not res.status_code == 200:
print(res.text)
res.raise_for_status()
ret[u] = res.json()["token"]

Expand Down
24 changes: 22 additions & 2 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
User,
__version__ as ver,
)
from kbase._auth.models import MFAStatus
from kbase._auth.models import MFAStatus, TokenType


def test_version():
Expand All @@ -34,6 +34,22 @@ def test_mfastatus_get_mfa_fail():
MFAStatus.get_mfa(mfa)


def test_tokentype_get_tokentype():
assert TokenType.get_type("Login") == TokenType.LOGIN
assert TokenType.get_type("agent") == TokenType.AGENT
assert TokenType.get_type("deVelopEr") == TokenType.DEVELOPER
assert TokenType.get_type("Service") == TokenType.SERVICE


def test_tokentype_get_tokentype_fail():
with pytest.raises(ValueError, match="type_ is required"):
TokenType.get_type(None)

for t in ["foo", "loginated", "dunno"]:
with pytest.raises(ValueError, match=f"Unknown token type string: {t}"):
TokenType.get_type(t)


async def _create_fail(url: str, expected: Exception, cachesize=1, timer=time.time):
with pytest.raises(type(expected), match=f"^{expected.args[0]}$"):
KBaseAuthClient.create(url, cache_max_size=cachesize, timer=timer)
Expand Down Expand Up @@ -120,6 +136,7 @@ async def test_get_token_basic(auth_users):
assert t1 == Token(
id=t1.id,
user="user",
type=TokenType.LOGIN,
cachefor=300000,
created=t1.created,
expires=t1.expires,
Expand All @@ -132,6 +149,7 @@ async def test_get_token_basic(auth_users):
assert t2 == Token(
id=t2.id,
user="user_random1",
type=TokenType.AGENT,
cachefor=300000,
created=t2.created,
expires=t2.expires,
Expand All @@ -141,9 +159,11 @@ async def test_get_token_basic(auth_users):
assert time_close_to_now(t2.created, 10)
assert t2.expires - t2.created == 3600000

# for the remaining tokens we just check mfa
# for the remaining tokens we just check mfa & type
assert t3.mfa == MFAStatus.NOT_USED
assert t3.type == TokenType.DEVELOPER
assert t4.mfa == MFAStatus.UNKNOWN
assert t4.type == TokenType.SERVICE


@pytest.mark.asyncio
Expand Down
Loading