diff --git a/src/kbase/auth/_async/client.py b/src/kbase/auth/_async/client.py index a5af7ea..0e3f5b8 100644 --- a/src/kbase/auth/_async/client.py +++ b/src/kbase/auth/_async/client.py @@ -38,6 +38,20 @@ class Token: _VALID_TOKEN_FIELDS = {f.name for f in fields(Token)} +@dataclass +class User: + """ Information about a KBase user. """ + user: str + """ The username of the user associated with the token. """ + customroles: list[str] + """ The Auth2 custom roles the user possesses. """ + # Not seeing any other fields that are generally useful right now + # Don't really want to expose idents unless there's a very good reason + + +_VALID_USER_FIELDS = {f.name for f in fields(User)} + + def _require_string(putative: str, name: str) -> str: if not isinstance(putative, str) or not putative.strip(): raise ValueError(f"{name} is required and cannot be a whitespace only string") @@ -98,7 +112,7 @@ async def create( await cli.close() raise # TODO CLIENT look through the myriad of auth clients to see what functionality we need - # TODO CLIENT cache user using cachefor value from token + # TODO CLIENT cache valid user names using cachefor value from token # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all reads only return cli @@ -114,6 +128,7 @@ def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int if not timer: raise ValueError("timer is required") self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer) + self._user_cache = LRUCache(maxsize=cache_max_size, timer=timer) self._cli = httpx.AsyncClient() async def __aenter__(self): @@ -157,3 +172,29 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> # in test mode self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) return tk + + async def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) -> User: + """ + Get information about a KBase user. This method caches the user; + further caching is unnecessary in most cases. + + If you just need the user name get_token is potentially cheaper. + + token - the token of the user to query. + on_cache_miss - a function to call if a cache miss occurs. + """ + # really similar to the above, not quite similar enough to make a shared method + _require_string(token, "token") + user = self._user_cache.get(token, default=False) + if user: + return user + if on_cache_miss: + on_cache_miss() + tk = await self.get_token(token) + res = await self._get(self._me_url, headers={"Authorization": token}) + u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS}) + # TODO TEST later may want to add tests that change the cachefor value. + # Cleanest way to do this is update the auth2 service to allow setting it + # in test mode + self._user_cache.set(token, u, ttl=tk.cachefor / 1000) + return u diff --git a/src/kbase/auth/_sync/client.py b/src/kbase/auth/_sync/client.py index e5e7af2..6921b75 100644 --- a/src/kbase/auth/_sync/client.py +++ b/src/kbase/auth/_sync/client.py @@ -38,6 +38,20 @@ class Token: _VALID_TOKEN_FIELDS = {f.name for f in fields(Token)} +@dataclass +class User: + """ Information about a KBase user. """ + user: str + """ The username of the user associated with the token. """ + customroles: list[str] + """ The Auth2 custom roles the user possesses. """ + # Not seeing any other fields that are generally useful right now + # Don't really want to expose idents unless there's a very good reason + + +_VALID_USER_FIELDS = {f.name for f in fields(User)} + + def _require_string(putative: str, name: str) -> str: if not isinstance(putative, str) or not putative.strip(): raise ValueError(f"{name} is required and cannot be a whitespace only string") @@ -85,7 +99,8 @@ def create( base_url - the base url for the authentication service, for example https://kbase.us/services/auth - cache_max_size - the maximum size of the token and user caches. + cache_max_size - the maximum size of the token and user caches. When the cache size is + exceeded, the least recently used entries are evicted from the cache. timer - the timer for the cache. Used for testing. Time unit must be seconds. """ cli = cls(base_url, cache_max_size, timer) @@ -97,7 +112,7 @@ def create( cli.close() raise # TODO CLIENT look through the myriad of auth clients to see what functionality we need - # TODO CLIENT cache user using cachefor value from token + # TODO CLIENT cache valid user names using cachefor value from token # TODO RELIABILITY could add retries for these methods, tenacity looks useful # should be safe since they're all reads only return cli @@ -113,6 +128,7 @@ def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int if not timer: raise ValueError("timer is required") self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer) + self._user_cache = LRUCache(maxsize=cache_max_size, timer=timer) self._cli = httpx.Client() def __enter__(self): @@ -156,3 +172,29 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token # in test mode self._token_cache.set(token, tk, ttl=tk.cachefor / 1000) return tk + + def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) -> User: + """ + Get information about a KBase user. This method caches the user; + further caching is unnecessary in most cases. + + If you just need the user name get_token is potentially cheaper. + + token - the token of the user to query. + on_cache_miss - a function to call if a cache miss occurs. + """ + # really similar to the above, not quite similar enough to make a shared method + _require_string(token, "token") + user = self._user_cache.get(token, default=False) + if user: + return user + if on_cache_miss: + on_cache_miss() + tk = self.get_token(token) + res = self._get(self._me_url, headers={"Authorization": token}) + u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS}) + # TODO TEST later may want to add tests that change the cachefor value. + # Cleanest way to do this is update the auth2 service to allow setting it + # in test mode + self._user_cache.set(token, u, ttl=tk.cachefor / 1000) + return u diff --git a/test/test_client.py b/test/test_client.py index e7f9d35..0c7d873 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -213,3 +213,127 @@ async def test_get_token_cache_evict_on_time(auth_users): ttt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss) assert cachemiss.call_count == 2 assert ttt1 == t1 + + +@pytest.mark.asyncio +async def test_get_user_basic(auth_users): + with KBaseAuthClient.create(AUTH_URL) as cli: + u1 = cli.get_user(auth_users["user"]) + u2 = cli.get_user(auth_users["user_all"]) + async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: + u3 = await cli.get_user(auth_users["user_random1"]) + u4 = await cli.get_user(auth_users["user_random2"]) + + assert u1.user == "user" + assert u1.customroles == [] + + assert u2.user == "user_all" + assert u2.customroles == ["random1", "random2"] + + assert u3.user == "user_random1" + assert u3.customroles == ["random1"] + + assert u4.user == "user_random2" + assert u4.customroles == ["random2"] + + +@pytest.mark.asyncio +async def test_get_user_basic_fail(auth_users): + err = "token is required and cannot be a whitespace only string" + await _get_user_basic_fail(None, ValueError(err)) + await _get_user_basic_fail(" \t ", ValueError(err)) + err = "KBase auth server reported token is invalid." + await _get_user_basic_fail("superfake", InvalidTokenError(err)) + + +async def _get_user_basic_fail(token: str, expected: Exception): + with KBaseAuthClient.create(AUTH_URL) as cli: + with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): + cli.get_user(token) + async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli: + with pytest.raises(type(expected), match=f"^{expected.args[0]}$"): + await cli.get_user(token) + + +@pytest.mark.asyncio +async def test_get_user_cache_evict_on_size(auth_users): + with KBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli: + cachemiss = Mock() + # fill the cache + u1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + u2 = cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss) + u3 = cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + # check userss in cache + uu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + uu2 = cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss) + uu3 = cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + assert uu1 == u1 + assert uu2 == u2 + assert uu3 == u3 + # Force an eviction + cli.get_user(auth_users["user_all"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 4 + # Check user was evicted + uuu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 5 + assert uuu1 == u1 + + async with await AsyncKBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli: + cachemiss = Mock() + # fill the cache + u1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + u2 = await cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss) + u3 = await cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + # check users in cache + uu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + uu2 = await cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss) + uu3 = await cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 3 + assert uu1 == u1 + assert uu2 == u2 + assert uu3 == u3 + # Force an eviction + await cli.get_user(auth_users["user_all"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 4 + # Check user was evicted + uuu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 5 + assert uuu1 == u1 + + +@pytest.mark.asyncio +async def test_get_user_cache_evict_on_time(auth_users): + timer = FakeTimer() + with KBaseAuthClient.create(AUTH_URL, timer=timer) as cli: + cachemiss = Mock() + u1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + # TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow + # setting different values and test here + timer.advance(299) + uu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + assert uu1 == u1 + timer.advance(2) + uuu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 2 + assert uuu1 == u1 + + timer = FakeTimer() + async with await AsyncKBaseAuthClient.create(AUTH_URL, timer=timer) as cli: + cachemiss = Mock() + u1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + # TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow + # setting different values and test here + timer.advance(299) + uu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 1 + assert uu1 == u1 + timer.advance(2) + uuu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss) + assert cachemiss.call_count == 2 + assert uuu1 == u1