From 8316341e92f668b30c9d034351750a6144d273ae Mon Sep 17 00:00:00 2001
From: jonasscheid
Date: Fri, 6 Feb 2026 15:12:44 +0000
Subject: [PATCH] Support comma-separated multiple categories in download-all-public-category-files

---
Review note: the invalid-category error now lists the offending values in
sorted order so the message is deterministic (set iteration order is not),
and the download_all_category_files docstring documents the fallback from
`categories` to `category` to RAW.

 pridepy/files/files.py          | 27 ++++++++++++++++++---------
 pridepy/pridepy.py              | 17 +++++++++++++----
 pridepy/tests/test_raw_files.py | 13 +++++++++++++
 3 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/pridepy/files/files.py b/pridepy/files/files.py
index 0d590ef..545dd1f 100644
--- a/pridepy/files/files.py
+++ b/pridepy/files/files.py
@@ -725,10 +725,11 @@ def download_all_category_files(
         protocol: str,
         aspera_maximum_bandwidth: str,
         checksum_check: bool,
-        category: str,
+        categories: List[str] = None,
+        category: str = None,
     ):
         """
-        Download all files of a specified category from a PRIDE project.
+        Download all files of specified categories from a PRIDE project.
 
         :param accession: The PRIDE project accession identifier.
         :param output_folder: The directory where the files will be downloaded.
@@ -736,9 +737,12 @@ def download_all_category_files(
         :param protocol: The transfer protocol to use (e.g., ftp, aspera, globus, s3).
         :param aspera_maximum_bandwidth: Maximum bandwidth for Aspera transfers.
         :param checksum_check: If True, downloads the checksum file for the project.
-        :param category: The category of files to download.
+        :param categories: List of file categories to download; if None, falls back to category, then to RAW.
+        :param category: Single file category (deprecated, use categories instead).
         """
-        raw_files = self.get_all_category_file_list(accession, category)
+        if categories is None:
+            categories = [category] if category else ["RAW"]
+        raw_files = self.get_all_category_file_list(accession, categories)
         self.download_files(
             raw_files,
             accession,
@@ -749,17 +753,22 @@ def download_all_category_files(
             checksum_check=checksum_check,
         )
 
-    def get_all_category_file_list(self, accession: str, category: str):
+    def get_all_category_file_list(
+        self, accession: str, categories: "str | List[str]"
+    ) -> List[Dict]:
         """
-        Retrieve a list of files from a specific project that belong to a given category.
+        Retrieve a list of files from a specific project that belong to given categories.
 
         :param accession: The PRIDE project accession identifier.
-        :param category: The category of files to filter by.
-        :return: A list of files in the specified category.
+        :param categories: A single category string or list of categories to filter by.
+        :return: A list of files matching the specified categories.
         """
         record_files = self.stream_all_files_by_project(accession)
+        if isinstance(categories, str):
+            categories = [categories]
+        category_set = set(categories)
         category_files = [
-            file for file in record_files if file["fileCategory"]["value"] == category
+            file for file in record_files if file["fileCategory"]["value"] in category_set
         ]
         return category_files
 
diff --git a/pridepy/pridepy.py b/pridepy/pridepy.py
index 74afe55..aa8fffb 100644
--- a/pridepy/pridepy.py
+++ b/pridepy/pridepy.py
@@ -124,8 +124,8 @@ def download_all_public_raw_files(
     "-c",
     "--category",
     required=True,
-    help="Category of the files to be downloaded",
-    type=click.Choice("RAW,PEAK,SEARCH,RESULT,SPECTRUM_LIBRARY,OTHER,FASTA".split(",")),
+    help="Comma-separated categories of files to download (e.g. RAW or RAW,SEARCH). "
+    "Valid values: RAW, PEAK, SEARCH, RESULT, SPECTRUM_LIBRARY, OTHER, FASTA",
 )
 def download_all_public_category_files(
     accession: str,
@@ -146,9 +146,18 @@ def download_all_public_category_files(
         skip_if_downloaded_already (bool): If True, skips downloading files that already exist. Default is False.
         aspera_maximum_bandwidth (str): Maximum bandwidth for Aspera transfers.
         checksum_check (bool): If True, downloads the checksum file for the project.
-        category (str): The category of files to download.
+        category (str): Comma-separated categories of files to download (e.g. RAW or RAW,SEARCH).
     """
 
+    valid_categories = {"RAW", "PEAK", "SEARCH", "RESULT", "SPECTRUM_LIBRARY", "OTHER", "FASTA"}
+    categories = [c.strip().upper() for c in category.split(",")]
+    invalid = set(categories) - valid_categories
+    if invalid:
+        raise click.BadParameter(
+            f"Invalid category: {', '.join(sorted(invalid))}. "
+            f"Valid values: {', '.join(sorted(valid_categories))}"
+        )
+
     raw_files = Files()
     logging.info("accession: " + accession)
     logging.info(f"Data will be downloaded from {protocol}")
@@ -163,7 +172,7 @@ def download_all_public_category_files(
         protocol,
         aspera_maximum_bandwidth=aspera_maximum_bandwidth,
         checksum_check=checksum_check,
-        category=category,
+        categories=categories,
     )
 
 
diff --git a/pridepy/tests/test_raw_files.py b/pridepy/tests/test_raw_files.py
index 631b8cf..1ce2ca3 100644
--- a/pridepy/tests/test_raw_files.py
+++ b/pridepy/tests/test_raw_files.py
@@ -37,3 +37,16 @@ def test_get_all_category_file_list(self):
         result = raw.get_all_category_file_list("PXD008644", "SEARCH")
 
         assert len(result) == 2
+
+    def test_get_all_category_file_list_multiple(self):
+        """
+        Test filtering by multiple categories at once.
+        PXD008644 has 2 RAW + 2 SEARCH = 4 files combined.
+        """
+        raw = Files()
+        result = raw.get_all_category_file_list("PXD008644", ["RAW", "SEARCH"])
+        assert len(result) == 4
+
+        # Verify both categories are present
+        categories = {file["fileCategory"]["value"] for file in result}
+        assert categories == {"RAW", "SEARCH"}