diff --git a/sample_annotator/clients/gold_client.py b/sample_annotator/clients/gold_client.py
index 75da84b..cfd0201 100644
--- a/sample_annotator/clients/gold_client.py
+++ b/sample_annotator/clients/gold_client.py
@@ -13,6 +13,11 @@
 from requests.auth import HTTPBasicAuth
 
 
+import hashlib
+import pickle
+
+cache = None
+
 USERPASS = Tuple[str, str]
 URL = str
 
@@ -33,10 +38,58 @@
 # but leaving this in as a stub in case this happens again
 EXCLUSION_LIST = []
 
+
+def build_cache_key(endpoint_url: str, params: dict, user: str, passwd: str) -> tuple:
+    """
+    Build a tuple identifying one API request.
+
+    The params are sorted by name, so two dicts with the same content
+    but a different insertion order produce the same key.
+
+    :param endpoint_url: full URL of the endpoint being called
+    :param params: query parameters of the request
+    :param user: user name part of the credentials
+    :param passwd: password part of the credentials
+    :return: ``(endpoint_url, sorted param items, user, passwd)``
+    """
+    normalized_params = tuple(sorted(params.items()))
+    return (endpoint_url, normalized_params, user, passwd)
+
+
+def ensure_cache_initialized():
+    """
+    Create the module-level cache if it has not been set yet.
+
+    The directory comes from the ``GOLD_CACHE_DIR`` environment variable
+    and defaults to ``cachedir``.
+    """
+    global cache
+    if cache is None:
+        path = os.environ.get("GOLD_CACHE_DIR", "cachedir")
+        logging.warning(f"[cache] cache was None. Falling back to: {path}")
+        cache = Cache(path)
+
+
+def set_cache_directory(path: str):
+    """
+    Override the default cache directory globally.
+
+    Any cache that is already open is closed first. get_fetch_url()
+    reads the module-level cache each time it is called, so later
+    requests use the new directory.
+
+    :param path: directory of the new cache
+    """
+    global cache
+    # Close (not clear) any previous cache instance
+    if cache is not None:
+        cache.close()
+    cache = Cache(path)
+    logging.info(f"Cache directory set to: {path}")
+
 
-@cache.memoize()
-@cache.memoize()
 def _fetch_url(endpoint_url, params, user, passwd) -> JSON:
+    logging.warning(f"[cache] API call made: {endpoint_url} {params}")
     attempt = 0
     while attempt < 4:
         results = requests.get(
@@ -53,6 +106,33 @@ def _fetch_url(endpoint_url, params, user, passwd) -> JSON:
             attempt += 1
     raise Exception(f"API call to {endpoint_url} failed after {attempt} attempts")
 
+
+def get_fetch_url():
+    """
+    Return a memoized wrapper around ``_fetch_url``.
+
+    The wrapper is built on every call, after making sure the
+    module-level cache exists, so it is always bound to the cache that
+    is current at that moment.
+
+    :return: function taking ``(endpoint_url, params, user, passwd)``
+    """
+    ensure_cache_initialized()
+
+    @cache.memoize()
+    def _memoized_fetch_url(endpoint_url, params, user, passwd):
+        # Fingerprint of the request, computed for the log line only
+        key_data = build_cache_key(endpoint_url, params, user, passwd)
+        key_bytes = pickle.dumps(key_data)
+        hashed_key = hashlib.sha256(key_bytes).hexdigest()
+
+        study_id = params.get("studyGoldId", "")
+        logging.warning(f"[cache] memoize key: {hashed_key} for studyGoldId={study_id}")
+
+        return _fetch_url(endpoint_url, params, user, passwd)
+
+    return _memoized_fetch_url
+
 
 class GoldClient:
     """
@@ -91,7 +171,32 @@ def _normalize_id(self, id: str) -> str:
     def _call(self, endpoint: str, params: Dict = {}) -> JSON:
         (user, passwd) = self.gold_key
         endpoint_url = f"{self.url}/{endpoint}"
-        obj = _fetch_url(endpoint_url, params, user, passwd)
+
+        # Compare with None instead of relying on truthiness: the cache
+        # supports len(), so an open but empty cache would be falsy.
+        if cache is not None:
+            # Sentinel, so a stored None still counts as a hit and an entry
+            # that disappears between check and read is not returned as None
+            miss = object()
+
+            # Try direct cache lookup first
+            key_data = build_cache_key(endpoint_url, params, user, passwd)
+            cached = cache.get(key_data, default=miss)
+            if cached is not miss:
+                logging.info(f"[cache] Direct cache hit for {endpoint} {params}")
+                return cached
+
+            # Also try with simpler key format that might be in cache
+            study_id = params.get("studyGoldId", "")
+            if study_id:
+                simple_key = f"{endpoint_url}-{study_id}-{user}-{passwd}"
+                cached = cache.get(simple_key, default=miss)
+                if cached is not miss:
+                    logging.info(f"[cache] Simple key cache hit for {endpoint} {params}")
+                    return cached
+
+        # Fall back to memoized function if not found directly
+        obj = get_fetch_url()(endpoint_url, params, user, passwd)
         self.num_calls += 1
         return obj
 
@@ -124,6 +229,8 @@ def fetch_biosamples_by_study(
         else:
             biosamples = self._call("biosamples", {"studyGoldId": id})
         if include_project:
+            logging.info(f"Cache contains {len(cache) if cache is not None else 0} records")
+            logging.info(f"Cache directory is at {cache.directory if cache is not None else 'None'}")
             projects = self.fetch_projects_by_study(id)
             # weave projects in samples
             samples_by_id = {