From d74d525bdd05e4b8a4bb446af7dd488123dbae9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 10:37:26 +0100 Subject: [PATCH 1/6] Separate update_version methods for project and output repo --- cadetrdm/repositories.py | 136 ++++++++++++++++++++++++++------------- 1 file changed, 92 insertions(+), 44 deletions(-) diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py index bb53e8d..7443f91 100644 --- a/cadetrdm/repositories.py +++ b/cadetrdm/repositories.py @@ -822,6 +822,8 @@ def __init__( self._output_uuid = self._metadata["output_uuid"] self._output_folder = self._metadata["output_remotes"]["output_folder_name"] self.options = options + self._update_version() + if not (self.path / self._output_folder).exists(): print("Output repository was missing, cloning now.") self._clone_output_repo() @@ -830,9 +832,6 @@ def __init__( self, ) - if self._metadata["cadet_rdm_version"] != cadetrdm.__version__: - self._update_version(self._metadata, cadetrdm.__version__) - self._on_context_enter_commit_hash = None self._is_in_context_manager = False self.options_hash = None @@ -864,58 +863,33 @@ def module(self) -> ModuleType: sys.path.remove(str(self.path)) os.chdir(cur_dir) - def _update_version(self, metadata, cadetrdm_version): - current_version = Version.coerce(metadata["cadet_rdm_version"]) + def _update_version(self) -> None: + """Update project repo to latest CADET-RDM specs.""" + metadata = self._metadata + cadetrdm_version = Version(cadetrdm.__version__) + current_version = Version(metadata["cadet_rdm_version"]) + + # Skip if versions match + if cadetrdm_version == current_version: + return changes_were_made = False - if SimpleSpec("<0.0.9").match(current_version): + if current_version < Version("0.0.9"): changes_were_made = True - self.output_repo._convert_csv_to_tsv_if_necessary() self._add_jupytext_file(self.path) - if SimpleSpec("<0.0.24").match(current_version): + if current_version < Version("0.0.24"): 
changes_were_made = True - self.output_repo._expand_tsv_header() output_remotes_path = self.path / "output_remotes.json" delete_path(output_remotes_path) self.add(output_remotes_path) - if SimpleSpec("<=0.0.34").match(current_version): - changes_were_made = True - if self.output_log_file.exists(): - warnings.warn( - "Repo version has outdated headers." - "Updating log.tsv." - ) - self.output_repo._update_headers() - if SimpleSpec("<0.0.34").match(current_version): - changes_were_made = True - self.output_repo._fix_gitattributes_log_tsv() - if SimpleSpec("<1.1.0").match(current_version): - # Note, this needs to be performed before upating the hashes, otherwise - # instantiating an `OutputLog` will crash when missing the - # `project_repo_branch` attribute. - changes_were_made = True - if self.output_log_file.exists(): - warnings.warn( - "Repo version has missing project repo branch_name field." - "Updating log.tsv." - ) - self.output_repo._add_branch_name_to_log() - if SimpleSpec("<0.1.7").match(current_version): - changes_were_made = True - if self.output_repo.output_log.n_entries > 0: - warnings.warn( - "Repo version has outdated options hashes. " - "Updating option hashes in output log.tsv." - ) - self.output_repo._update_log_hashes() if changes_were_made: print( f"Repo version {metadata['cadet_rdm_version']} was outdated. " f"Current CADET-RDM version is {cadetrdm.__version__}.\n" "Repo has been updated." 
) - metadata["cadet_rdm_version"] = cadetrdm_version + metadata["cadet_rdm_version"] = str(cadetrdm_version) with open(self.data_json_path, "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2) self.add(self.data_json_path) @@ -1489,6 +1463,8 @@ def __init__( self.project_repo = project_repo super().__init__(*args, **kwargs) + self._update_version() + @property def output_log_file_path(self): if not self.active_branch == self.main_branch: @@ -1561,7 +1537,6 @@ def commit_to_options_map(self) -> dict[str, list[str]]: mapping[entry.project_repo_commit_hash].append(entry.options_hash) return dict(mapping) - def add_filetype_to_lfs(self, file_type): """ Add the filetype given in file_type to the GIT-LFS tracking @@ -1574,6 +1549,69 @@ def add_filetype_to_lfs(self, file_type): self.add_all_files() self.commit(f"Add {file_type} to lfs") + def _update_version(self) -> None: + """Update output repo to latest CADET-RDM specs.""" + metadata = self._metadata + cadetrdm_version = Version(cadetrdm.__version__) + current_version = Version(metadata["cadet_rdm_version"]) + + # Skip if versions match + if cadetrdm_version == current_version: + return + + changes_were_made = False + + if current_version < Version("0.0.9"): + changes_were_made = True + self._convert_csv_to_tsv_if_necessary() + if current_version < Version("0.0.24"): + changes_were_made = True + self._expand_tsv_header() + if current_version < Version("0.0.34"): + changes_were_made = True + if self.output_log_file_path.exists(): + warnings.warn( + "Repo version has outdated headers. " + "Updating log.tsv." + ) + self._update_headers() + if current_version < Version("0.0.34"): + changes_were_made = True + self._fix_gitattributes_log_tsv() + if current_version < Version("1.1.0"): + # Note, this needs to be performed before updating the hashes, otherwise + # instantiating an `OutputLog` will crash when missing the + # `project_repo_branch` attribute. 
+ changes_were_made = True + if self.output_log_file_path.exists(): + warnings.warn( + "Repo version has missing project repo branch_name field. " + "Updating log.tsv." + ) + self._add_branch_name_to_log() + if current_version < Version("0.1.7"): + changes_were_made = True + if self.output_log.n_entries > 0: + warnings.warn( + "Repo version has outdated options hashes. " + "Updating option hashes in output log.tsv." + ) + self._update_log_hashes() + if changes_were_made: + print( + f"Repo version {metadata['cadet_rdm_version']} was outdated. " + f"Current CADET-RDM version is {cadetrdm.__version__}.\n" + "Repo has been updated." + ) + metadata["cadet_rdm_version"] = str(cadetrdm_version) + with open(self.data_json_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + self.add(self.data_json_path) + self.commit( + f"Update CADET-RDM version to {cadetrdm_version}", + add_all=False + ) + def _convert_csv_to_tsv_if_necessary(self) -> None: """Convert logfile from csv to tsv format.""" if self.output_log_file_path.exists(): @@ -1596,6 +1634,9 @@ def _convert_csv_to_tsv_if_necessary(self) -> None: lines=["rdm-log.tsv merge=union"], open_type="a" ) + self.add(self.path / "log.csv") + self.add(self.path / "log.tsv") + self.commit("Convert csv to tsv", add_all=False) def _expand_tsv_header(self): """Update tsv header.""" @@ -1614,7 +1655,8 @@ def _expand_tsv_header(self): "Project repo remotes", "Python sys args", "Tags", - "Options hash", ] + "Options hash", + ] with open(self.output_log_file_path, "w", encoding="utf-8") as f: f.writelines(["\t".join(new_header) + "\n"]) f.writelines(lines[1:]) @@ -1656,6 +1698,7 @@ def _fix_gitattributes_log_tsv(self): lines = [line.replace("rdm-log.tsv", "log.tsv") for line in lines] with open(file, "w", encoding="utf-8") as handle: handle.writelines(lines) + self.add(".gitattributes") self.commit("Update .gitattributes", add_all=False) @@ -1680,7 +1723,9 @@ def _update_log_hashes(self): self.checkout(self.main_branch) 
if self.output_log.n_entries > 0: log.write() - self.commit(message="Updated log hashes", add_all=True) + + self.add(self.output_log_file_path) + self.commit(message="Updated log hashes", add_all=False) def _add_branch_name_to_log(self) -> None: """ @@ -1713,7 +1758,10 @@ def _add_branch_name_to_log(self) -> None: writer.writerows(rows) self.add("log.tsv") - self.commit(message="Add project_repo_branch_name to log.tsv") + self.commit( + message="Add project_repo_branch_name to log.tsv", + add_all=False, + ) class JupyterInterfaceRepo(ProjectRepo): From bc979d5e35bfb614bd8547ed4d75716d43e8244c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 11:44:24 +0100 Subject: [PATCH 2/6] Pass all args to super().commit --- cadetrdm/repositories.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py index 7443f91..f94628d 100644 --- a/cadetrdm/repositories.py +++ b/cadetrdm/repositories.py @@ -1,5 +1,6 @@ import contextlib import csv +from functools import wraps import glob import importlib import json @@ -285,7 +286,12 @@ def remote_set_url(self, name: str, url: str): """ self._git_repo.remotes[name].set_url(url) - def commit(self, message: str | None = None, add_all=True, verbosity=1): + def commit( + self, + message: str | None = None, + add_all=True, + verbosity=1, + ) -> None: """ Commit current state of the repository. 
@@ -1079,8 +1085,7 @@ def update_output_main_logs(self, output_dict: dict = None): self.output_repo.add(".") self.output_repo._git.commit( "-m", - f"log for '{output_commit_message}' \n" - f"of branch '{output_branch_name}'" + f"log for '{output_commit_message}' of branch '{output_branch_name}'", ) self.output_repo._git.checkout(output_branch_name) @@ -1103,20 +1108,15 @@ def _copy_code(self, target_path): self.active_branch, output=code_tmp_folder ) - def commit(self, message: str | None = None, add_all=True, verbosity=1): - """ - Commit current state of the repository. - - :param message: - Commit message - :param add_all: - Option to add all changed and new files to git automatically. - :param verbosity: - Option to choose degree of printed feedback. - """ + @wraps(BaseRepo.commit) + def commit( + self, + *args, + **kwargs, + ): + """Update output remotes before committing.""" self.update_output_remotes_json() - - super().commit(message=message, add_all=add_all, verbosity=verbosity) + super().commit(*args, **kwargs) def check(self, commit=True): """ From 6fe2a7e892176f9b493da5200d2cc03402c32536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 10:47:44 +0100 Subject: [PATCH 3/6] Remove options from project repo --- cadetrdm/repositories.py | 88 ++++++++++++++++++++++++++++------------ cadetrdm/wrapper.py | 7 ++-- 2 files changed, 66 insertions(+), 29 deletions(-) diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py index f94628d..ad2c164 100644 --- a/cadetrdm/repositories.py +++ b/cadetrdm/repositories.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import csv from functools import wraps @@ -773,7 +775,6 @@ def __init__( suppress_lfs_warning: bool = False, url: str = None, branch: str = None, - options: Options | None = None, package_dir: str | None = None, *args: Any, **kwargs: Any, @@ -798,8 +799,6 @@ def __init__( Optional branch to check out upon initialization :param 
package_dir: Name of the directory containing the main package. - :param options: - Options dictionary containing ... :param args: Additional args to be handed to BaseRepo. :param kwargs: @@ -827,7 +826,7 @@ def __init__( self._project_uuid = self._metadata["project_uuid"] self._output_uuid = self._metadata["output_uuid"] self._output_folder = self._metadata["output_remotes"]["output_folder_name"] - self.options = options + self._update_version() if not (self.path / self._output_folder).exists(): @@ -840,7 +839,6 @@ def __init__( self._on_context_enter_commit_hash = None self._is_in_context_manager = False - self.options_hash = None if branch is not None: self.checkout(branch) @@ -959,9 +957,10 @@ def create_remotes(self, name, namespace, url=None, username=None, push=True): if errors_encountered == 0 and push: self.push(push_all=True) - def get_new_output_branch_name(self): + def get_new_output_branch_name(self, branch_prefix: str | None = None) -> str: """ Construct a name for the new branch in the output repository. + :param branch_prefix: Optional branch name prefix. 
:return: the new branch name """ project_repo_hash = str(self.head.commit) @@ -969,8 +968,8 @@ def get_new_output_branch_name(self): branch_name = f"{timestamp}_{self.active_branch}_{project_repo_hash[:7]}" - if self.options and "branch_prefix" in self.options: - branch_name = f"{self.options['branch_prefix']}_{branch_name}" + if branch_prefix: + branch_name = f"{branch_prefix}_{branch_name}" return branch_name @@ -1030,7 +1029,11 @@ def output_log_file(self): def output_log(self): return self.output_repo.output_log - def update_output_main_logs(self, output_dict: dict = None): + def update_output_main_logs( + self, + output_dict: dict = None, + options: Options | None = None, + ): """ Dumps all the metadata information about the project repositories state and the commit hash and branch name of the ouput repository into the main branch of @@ -1063,7 +1066,7 @@ def update_output_main_logs(self, output_dict: dict = None): project_repo_remotes=self.remote_urls, python_sys_args=str(sys.argv), tags=", ".join(self.tags), - options_hash=self.options_hash, + options_hash=options.get_hash() if options else None, filepath=None, **output_dict ) @@ -1071,8 +1074,8 @@ def update_output_main_logs(self, output_dict: dict = None): with open(logs_dir / "metadata.json", "w", encoding="utf-8") as f: json.dump(entry.to_dict(), f, indent=2) - if self.options is not None: - self.options.dump_json_file(logs_dir / "options.json", indent=2) + if options: + options.dump_json_file(logs_dir / "options.json", indent=2) log = OutputLog(self.output_log_file) log.entries[output_branch_name] = entry @@ -1217,7 +1220,12 @@ def import_static_data(self, source_path: Path | str, commit_message): self._commit_output_data(commit_message, output_dict={}) return new_branch_name - def enter_context(self, force=False, debug=False): + def enter_context( + self, + force=False, + debug=False, + branch_prefix: str | None = None, + ) -> str | None: """ Enter the tracking context. 
This includes: - Ensure no uncommitted changes in the project repository @@ -1225,11 +1233,12 @@ def enter_context(self, force=False, debug=False): - Clean up empty branches in the output repository - Create a new empty output branch in the output repository - :param force: If False, wait for user prompts before deleting data during clean up. If True, don't wait, just delete. :param debug: If True, just return None. + :param branch_prefix: + Optional branch name prefix. :return: The name of the newly created output branch. """ @@ -1244,10 +1253,14 @@ def enter_context(self, force=False, debug=False): self._on_context_enter_commit_hash = self.current_commit_hash self._is_in_context_manager = True - new_branch_name = self._get_new_output_branch(force) + new_branch_name = self._get_new_output_branch(force, branch_prefix) return new_branch_name - def _get_new_output_branch(self, force=False): + def _get_new_output_branch( + self, + force: bool = False, + branch_prefix: str | None = None + ): """ Prepares a new branch to receive data. This includes: - checking out the output main branch, @@ -1257,8 +1270,9 @@ def _get_new_output_branch(self, force=False): :param force: If False, wait for user prompts before deleting data during clean up. If True, don't wait, just delete. + :param branch_prefix: + Optional branch name prefix. 
""" - output_repo = self.output_repo # ensure that LFS is properly initialized @@ -1268,7 +1282,7 @@ def _get_new_output_branch(self, force=False): if output_repo.has_uncomitted_changes: output_repo._reset_hard_to_head(force_entry=force) output_repo.delete_active_branch_if_branch_is_empty() - new_branch_name = self.get_new_output_branch_name() + new_branch_name = self.get_new_output_branch_name(branch_prefix) # update urls in main branch of output_repo output_repo._git.checkout(output_repo.main_branch) @@ -1367,7 +1381,12 @@ def copy_data_to_cache(self, branch_name=None, target_folder=None): return target_folder - def exit_context(self, message, output_dict: dict = None): + def exit_context( + self, + message, + output_dict: dict = None, + options: Options | None = None, + ): """ After running all project code, this prepares the commit of the results to the output repository. This includes - Ensure no uncommitted changes in the project repository @@ -1386,9 +1405,14 @@ def exit_context(self, message, output_dict: dict = None): if self._on_context_enter_commit_hash != self.current_commit_hash: raise RuntimeError("Code has changed since starting the context. Don't do that.") - self._commit_output_data(message, output_dict) + self._commit_output_data(message, output_dict, options) - def _commit_output_data(self, message, output_dict): + def _commit_output_data( + self, + message: str, + output_dict: dict, + options: Options | None = None + ): """ Commit the data in the output repository. - Stage all changes in the output repository @@ -1398,6 +1422,8 @@ def _commit_output_data(self, message, output_dict): Commit message for the output repository commit. :param output_dict: Dictionary containing optional output tracking parameters + :param options: + Optional case options. 
""" print("Completed computations, commiting results") self.output_repo.add(".") @@ -1405,7 +1431,7 @@ def _commit_output_data(self, message, output_dict): # This has to be using ._git.commit to raise an error if no results have been written. commit_return = self.output_repo._git.commit("-m", message) self.copy_data_to_cache() - self.update_output_main_logs(output_dict) + self.update_output_main_logs(output_dict, options) main_cach_path = self.path / (self._output_folder + "_cached") / self.output_repo.main_branch if main_cach_path.exists(): delete_path(main_cach_path) @@ -1419,7 +1445,13 @@ def _commit_output_data(self, message, output_dict): self._on_context_enter_commit_hash = None @contextlib.contextmanager - def track_results(self, results_commit_message: str, debug=False, force=False): + def track_results( + self, + results_commit_message: str, + debug=False, + force=False, + options: Options | None = None, + ) -> str | None: """ Context manager to be used when running project code that produces output that should be tracked in the output repository. @@ -1429,6 +1461,8 @@ def track_results(self, results_commit_message: str, debug=False, force=False): Perform calculations without tracking output. :param force: Skip confirmation and force tracking of results. + :param options: + Optional case options. 
""" if debug: yield "debug" @@ -1439,14 +1473,18 @@ def track_results(self, results_commit_message: str, debug=False, force=False): yield "detached_head" return - new_branch_name = self.enter_context(force=force) + new_branch_name = self.enter_context( + force=force, + debug=debug, + branch_prefix=options.get("branch_prefix"), + ) try: yield new_branch_name except Exception as e: self.capture_error(e) raise e else: - self.exit_context(message=results_commit_message) + self.exit_context(message=results_commit_message, options=options) def capture_error(self, error): print(traceback.format_exc()) diff --git a/cadetrdm/wrapper.py b/cadetrdm/wrapper.py index 69248e0..486bb20 100644 --- a/cadetrdm/wrapper.py +++ b/cadetrdm/wrapper.py @@ -29,14 +29,13 @@ def wrapper(options, repo_path='.'): if options.get_hash() != Options.load_json_str(options.dump_json_str()).get_hash(): raise ValueError("Options are not serializable. Please only use python natives and numpy ndarrays.") - project_repo = ProjectRepo(repo_path, options=options) - - project_repo.options_hash = options.get_hash() + project_repo = ProjectRepo(repo_path) with project_repo.track_results( options.commit_message, debug=options.debug, - force=True + force=True, + options=options, ) as new_branch_name: options.dump_json_file(project_repo.output_path / "options.json") results = func(project_repo, options) From 3846e53216e8fa1d38ee9c5f3f21e9c9869475a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 13:13:08 +0100 Subject: [PATCH 4/6] fixup! 
Remove options from project repo --- cadetrdm/repositories.py | 2 +- tests/test_git_adapter.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cadetrdm/repositories.py b/cadetrdm/repositories.py index ad2c164..37ae17f 100644 --- a/cadetrdm/repositories.py +++ b/cadetrdm/repositories.py @@ -1476,7 +1476,7 @@ def track_results( new_branch_name = self.enter_context( force=force, debug=debug, - branch_prefix=options.get("branch_prefix"), + branch_prefix=options.get("branch_prefix") if options else None, ) try: yield new_branch_name diff --git a/tests/test_git_adapter.py b/tests/test_git_adapter.py index f53f343..d407e95 100644 --- a/tests/test_git_adapter.py +++ b/tests/test_git_adapter.py @@ -430,3 +430,7 @@ def test_with_detached_head(): # repo.verify_unchanged_cache() # # os.chdir("..") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) From 78e409410ee6e4ffa8f00784ced26bb83ca08a3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 14:05:17 +0100 Subject: [PATCH 5/6] fixup! 
Remove options from project repo --- tests/test_options.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_options.py b/tests/test_options.py index b8108cf..c11e8b5 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -1,11 +1,13 @@ -import numpy as np +from pathlib import Path import re +import numpy as np +import pytest + from cadetrdm import Options from cadetrdm.options import remove_invalid_keys -from cadetrdm import process_example from cadetrdm import ProjectRepo -from pathlib import Path + def test_options_hash(): opt = Options() @@ -115,7 +117,7 @@ def test_branch_name(): options.push = False options.source_directory = "src" - repo = ProjectRepo(Path("./test_repo_cli"), options=options) + repo = ProjectRepo(Path("./test_repo_cli")) hash = str(repo.head.commit)[:7] active_branch = str(repo.active_branch) @@ -136,14 +138,18 @@ def test_branch_name_with_prefix(): options.source_directory = "src" options.branch_prefix = "Test_Prefix" - repo = ProjectRepo(Path("./test_repo_cli"), options=options) + repo = ProjectRepo(Path("./test_repo_cli")) hash = str(repo.head.commit)[:7] active_branch = str(repo.active_branch) - new_branch = repo.get_new_output_branch_name() + new_branch = repo.get_new_output_branch_name(options.branch_prefix) escaped_branch = re.escape(active_branch) pattern = rf"^Test_Prefix_\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}$" - assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format" \ No newline at end of file + assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format" + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) From cd813507fe5f4ede5ce2b79876fed442be02da27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 26 Jan 2026 14:06:11 +0100 Subject: [PATCH 6/6] Refactor test_options --- tests/test_options.py | 144 
+++++++++++++++--------------------------- 1 file changed, 52 insertions(+), 92 deletions(-) diff --git a/tests/test_options.py b/tests/test_options.py index c11e8b5..5f305b6 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -1,14 +1,23 @@ -from pathlib import Path import re import numpy as np import pytest +from cadetrdm import initialize_repo from cadetrdm import Options from cadetrdm.options import remove_invalid_keys from cadetrdm import ProjectRepo +@pytest.fixture +def clean_repo(tmp_path): + """Fixture to initialize and clean up a test repository.""" + repo_path = tmp_path / "test_repo_cli" + initialize_repo(repo_path) + repo = ProjectRepo(repo_path) + yield repo + + def test_options_hash(): opt = Options() opt["array"] = np.linspace(2, 200) @@ -21,116 +30,71 @@ def test_options_hash(): assert opt == opt_recovered -def test_options_file_io(): +def test_options_file_io(tmp_path): opt = Options() opt["array"] = np.linspace(0, 2, 200) opt["nested_dict"] = {"ba": "foo", "bb": "bar"} initial_hash = opt.get_hash() - opt.dump_json_file("options.json") - opt_recovered = Options.load_json_file("options.json") + opt.dump_json_file(tmp_path / "options.json") + opt_recovered = Options.load_json_file(tmp_path / "options.json") post_serialization_hash = opt_recovered.get_hash() assert initial_hash == post_serialization_hash assert opt == opt_recovered -def test_remove_keys_starting_with_underscore(): - input_dict = { - "_private": 1, - "valid": 2, - "__magic__": 3 - } - expected = {"valid": 2} - assert remove_invalid_keys(input_dict) == expected - - -def test_remove_keys_containing_double_underscore(): - input_dict = { - "normal": 1, - "with__double": 2, - "another": 3 - } - expected = {"normal": 1, "another": 3} - assert remove_invalid_keys(input_dict) == expected - - -def test_nested_dict_removal(): - input_dict = { - "level1": { - "_invalid": 1, - "valid": { - "__still_invalid__": 2, - "ok": 3 - } - }, - "__should_be_removed__": "nope" - } - expected = 
{ - "level1": { - "valid": { - "ok": 3 - } - } - } +@pytest.mark.parametrize( + "input_dict, expected", + [ + ({"_private": 1, "valid": 2, "__magic__": 3}, {"valid": 2}), + ( + { + "level1": { + "_invalid": 1, + "valid": {"__still_invalid__": 2, "ok": 3}, + }, + "__should_be_removed__": "nope", + }, + {"level1": {"valid": {"ok": 3}}}, + ), + ({}, {}), + ({"_one": 1, "__two__": 2, "with__double": 3}, {}), + ({"a": 1, "b": {"c": 2}}, {"a": 1, "b": {"c": 2}}), + ], + ids=[ + "keys_starting_with_underscore", + "nested_dict_removal", + "empty_dict", + "all_invalid_keys", + "no_invalid_keys", + ], +) +def test_remove_invalid_keys(input_dict, expected): assert remove_invalid_keys(input_dict) == expected -def test_empty_dict(): - assert remove_invalid_keys({}) == {} - - -def test_all_invalid_keys(): - input_dict = { - "_one": 1, - "__two__": 2, - "with__double": 3 - } - assert remove_invalid_keys(input_dict) == {} - - -def test_no_invalid_keys(): - input_dict = { - "a": 1, - "b": { - "c": 2 - } - } - assert remove_invalid_keys(input_dict) == input_dict - -def test_explicit_invalid_keys(): - input_dict = { - "a": 1, - "b": { - "c": 2 - } - } - expected = { - "b": { - "c": 2 - } - } +def test_remove_explicit_invalid_keys(): + input_dict = {"a": 1, "b": {"c": 2}} + expected = {"b": {"c": 2}} assert remove_invalid_keys(input_dict, excluded_keys=["a"]) == expected -def test_branch_name(): + +def test_branch_name(clean_repo): options = Options() options.commit_message = "Commit Message Test" options.debug = True options.push = False options.source_directory = "src" - repo = ProjectRepo(Path("./test_repo_cli")) - - hash = str(repo.head.commit)[:7] - active_branch = str(repo.active_branch) - new_branch = repo.get_new_output_branch_name() + hash = str(clean_repo.head.commit)[:7] + active_branch = str(clean_repo.active_branch) + new_branch = clean_repo.get_new_output_branch_name() escaped_branch = re.escape(active_branch) - pattern = 
rf"^\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}$" - assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format" -def test_branch_name_with_prefix(): +def test_branch_name_with_prefix(clean_repo): options = Options() options.commit_message = "Commit Message Test" options.debug = True @@ -138,16 +102,12 @@ def test_branch_name_with_prefix(): options.source_directory = "src" options.branch_prefix = "Test_Prefix" - repo = ProjectRepo(Path("./test_repo_cli")) - - hash = str(repo.head.commit)[:7] - active_branch = str(repo.active_branch) - new_branch = repo.get_new_output_branch_name(options.branch_prefix) + hash = str(clean_repo.head.commit)[:7] + active_branch = str(clean_repo.active_branch) + new_branch = clean_repo.get_new_output_branch_name(options.branch_prefix) escaped_branch = re.escape(active_branch) - pattern = rf"^Test_Prefix_\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}$" - assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format"