diff --git a/pyproject.toml b/pyproject.toml index d9cae07..9044e78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "eed-basic-utils" -version = "1.0.0" +version = "1.1.0" description = "Add your description here" readme = "README.md" authors = [ diff --git a/src/eed_basic_utils/os/__init__.py b/src/eed_basic_utils/os/__init__.py index fad5a30..4ae55e8 100644 --- a/src/eed_basic_utils/os/__init__.py +++ b/src/eed_basic_utils/os/__init__.py @@ -1,5 +1,7 @@ __all__ = [ "file_exists", + "get_git_root", ] from .file_exists import file_exists +from .get_git_root import get_git_root diff --git a/src/eed_basic_utils/os/get_git_root.py b/src/eed_basic_utils/os/get_git_root.py new file mode 100644 index 0000000..a1071f5 --- /dev/null +++ b/src/eed_basic_utils/os/get_git_root.py @@ -0,0 +1,35 @@ +from pathlib import Path +from typing import Optional + + +def get_git_root(path: Optional[Path] = None) -> Path | None: + """ + Returns the root directory of the git repository containing the given path. + + Parameters + ---------- + path : Optional[Path], optional + The starting path to search for the git root. If None, uses the current working directory. + + Returns + ------- + Path or None + The root directory of the git repository, or None if not found. + + Notes + ----- + Searches upwards from the given path until a directory containing a `.git` folder is found. + """ + if path is None: + path = Path.cwd() + + if (path / ".git").exists(): + return path + + git_root = next(iter(path.parents)) + while not (git_root / ".git").exists() and git_root != git_root.parent: + git_root = git_root.parent + if not (git_root / ".git").exists(): + return None + else: + return git_root diff --git a/tests/conftest.py b/tests/conftest.py index 4aa3591..3901b33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import os from collections.abc import Iterator from pathlib import Path @@ -8,9 +7,17 @@ @pytest.fixture(scope="session") def cfg_test() -> Iterator[dict]: - home_dir = os.path.expanduser("~") + home_dir = Path.home() + cwd = Path.cwd() + for item in cwd.iterdir(): + if item.is_dir() and item.name[0] != ".": + subfolder = item + break + # Ensure cwd is a valid directory cfg_test = { "home_dir": home_dir, + "cwd": cwd, + "cwd_subfolder": cwd / subfolder, # Example subfolder, adjust as needed } yield cfg_test diff --git a/tests/os/test_get_git_root.py b/tests/os/test_get_git_root.py new file mode 100644 index 0000000..ef68754 --- /dev/null +++ b/tests/os/test_get_git_root.py @@ -0,0 +1,33 @@ +from eed_basic_utils.os import get_git_root +import pytest + + +@pytest.mark.parametrize( + "path, expected", + [ + ("cwd", "eed_basic_utils"), + ("cwd_subfolder", "eed_basic_utils"), + ("home_dir", None), + ], + ids=[ + "cwd = git root", + "subfolder", + "non_git_path", + ], +) +def test_get_git_root(cfg_test, path, expected): + """Test the get_git_root function.""" + if path == "cwd": + path = cfg_test["cwd"] + elif path == "cwd_subfolder": + path = cfg_test["cwd_subfolder"] + elif path == "home_dir": + path = cfg_test["home_dir"] + + git_root = get_git_root(path) + if git_root is None: + assert expected is git_root + else: + assert git_root.name == expected + assert (git_root / ".git").exists() + assert git_root.is_dir() diff --git a/uv.lock b/uv.lock index e79575c..4a4bf28 100644 --- a/uv.lock +++ b/uv.lock @@ -31,7 +31,7 @@ wheels = [ [[package]] name = "eed-basic-utils" -version = "0.1.2" +version = "1.1.0" source = { editable = "." } [package.dev-dependencies]