Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions sample_annotator/clients/gold_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

from requests.auth import HTTPBasicAuth

import hashlib
import pickle

cache = None


USERPASS = Tuple[str, str]
URL = str
Expand All @@ -33,10 +38,32 @@
# but leaving this in as a stub in case this happens again
EXCLUSION_LIST = []

def build_cache_key(endpoint_url: str, params: dict, user: str, passwd: str) -> str:
normalized_params = tuple(sorted(params.items()))
key_data = (endpoint_url, normalized_params, user, passwd)
return key_data

def ensure_cache_initialized():
global cache
if cache is None:
path = os.environ.get("GOLD_CACHE_DIR", "cachedir")
logging.warning(f"[cache] cache was None. Falling back to: {path}")
cache = Cache(path)

def set_cache_directory(path: str):
"""
Override the default cache directory globally.
Must be called before any memoized function is used.
"""
global cache
# Clear any previous cache instance
if cache is not None:
cache.close()
cache = Cache(path)
logging.info(f"Cache directory set to: {path}")

@cache.memoize()
@cache.memoize()
def _fetch_url(endpoint_url, params, user, passwd) -> JSON:
logging.warning(f"[cache] API call made: {endpoint_url} {params}")
attempt = 0
while attempt < 4:
results = requests.get(
Expand All @@ -53,6 +80,22 @@ def _fetch_url(endpoint_url, params, user, passwd) -> JSON:
attempt += 1
raise Exception(f"API call to {endpoint_url} failed after {attempt} attempts")

def get_fetch_url():
ensure_cache_initialized()

@cache.memoize()
def _memoized_fetch_url(endpoint_url, params, user, passwd):
key_data = build_cache_key(endpoint_url, params, user, passwd)
key_bytes = pickle.dumps(key_data)
hashed_key = hashlib.sha256(key_bytes).hexdigest()

study_id = params.get("studyGoldId", "<none>")
logging.warning(f"[cache] memoize key: {hashed_key} for studyGoldId={study_id}")

return _fetch_url(endpoint_url, params, user, passwd)

return _memoized_fetch_url


class GoldClient:
"""
Expand Down Expand Up @@ -91,7 +134,23 @@ def _normalize_id(self, id: str) -> str:
def _call(self, endpoint: str, params: Dict = {}) -> JSON:
(user, passwd) = self.gold_key
endpoint_url = f"{self.url}/{endpoint}"
obj = _fetch_url(endpoint_url, params, user, passwd)

# Try direct cache lookup first
key_data = build_cache_key(endpoint_url, params, user, passwd)
if cache and key_data in cache:
logging.info(f"[cache] Direct cache hit for {endpoint} {params}")
return cache.get(key_data)

# Also try with simpler key format that might be in cache
study_id = params.get("studyGoldId", "")
if study_id:
simple_key = f"{endpoint_url}-{study_id}-{user}-{passwd}"
if cache and simple_key in cache:
logging.info(f"[cache] Simple key cache hit for {endpoint} {params}")
return cache.get(simple_key)

# Fall back to memoized function if not found directly
obj = get_fetch_url()(endpoint_url, params, user, passwd)
self.num_calls += 1
return obj

Expand All @@ -109,7 +168,7 @@ def fetch_projects_by_study(self, id: str) -> List[SampleDict]:
return results

def fetch_biosamples_by_study(
self, id: str, include_project=True
self, id: str, include_project=True
) -> List[SampleDict]:
"""
Fetches all samples for a study
Expand All @@ -124,6 +183,8 @@ def fetch_biosamples_by_study(
else:
biosamples = self._call("biosamples", {"studyGoldId": id})
if include_project:
logging.info(f"Cache contains {len(cache) if cache else 0} records")
logging.info(f"Cache directory is at {cache.directory if cache else 'None'}")
projects = self.fetch_projects_by_study(id)
# weave projects in samples
samples_by_id = {
Expand Down