diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cfb057b..d836fec 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,6 +12,11 @@ on: - "docs/**" - "**.md" +# Cancel in-progress runs when a new commit is pushed +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: lint: name: Lint with Ruff @@ -22,22 +27,11 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - version: "0.9.5" + enable-cache: true + cache-dependency-glob: "pyproject.toml" - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.11" - - - name: Install dependencies - run: | - uv pip install --system ruff - - - name: Run ruff check - run: | - ruff check torchax test test_dist - - - name: Run ruff format check - run: | - ruff format --check torchax test test_dist + run: uv python install 3.11 + - name: Run linter + run: make lint-check diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7bf4fe2..db71e26 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -19,27 +19,32 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: '3.11' + enable-cache: true + + - name: Set up Python + run: uv python install 3.11 - - name: Install build dependencies - run: pip install build + - name: Create virtual environment + run: uv venv - name: Patch version for Nightly run: | + source .venv/bin/activate python scripts/update_to_nightly_version.py - # Verify the change (Debugging) grep "__version__" torchax/__init__.py - - name: Build Wheel and Sdist - run: python -m build --wheel + - name: Build package + run: | + source .venv/bin/activate + make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - # Note: We are NOT providing a password here because we will use Trusted Publishing. + # Note: Using Trusted Publishing (passwordless) # If you must use a token, uncomment the lines below: # with: # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7ba0ee9..5111ec3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: # in PyPI Settings > Publishing for your release configuration. environment: name: pypi - url: https://pypi.org/p/torchax # Replace 'torchax' with your package name + url: https://pypi.org/p/torchax permissions: # MANDATORY for OIDC Trusted Publishing @@ -30,18 +30,22 @@ jobs: with: fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: '3.11' + enable-cache: true + + - name: Set up Python + run: uv python install 3.11 - - name: Install build dependencies - run: pip install build + - name: Create virtual environment + run: uv venv - - name: Build Wheel - # We build both here, as stable releases usually offer both formats. - run: python -m build --wheel + - name: Build package + run: | + source .venv/bin/activate + make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - # Trusted Publishing handles authentication automatically. + # Trusted Publishing handles authentication automatically diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index 66fb1ba..ccbd62b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -31,43 +31,49 @@ jobs: steps: - name: Checkout uses: actions/checkout@v5 - - name: Set up Python 3.12 - uses: actions/setup-python@v6 + + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: 3.12 + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python + run: uv python install 3.12 + + - name: Create virtual environment + run: uv venv + - name: Setup Pages uses: actions/configure-pages@v5 + - name: Install dependencies run: | - pip install --upgrade pip - # note: installing torch and torchvision together ensures their version compatibility - pip install torch==2.8.0 torchvision --index-url https://download.pytorch.org/whl/cpu - pip install jupyter nbconvert pandas matplotlib - pip install -r test-requirements.txt - pip install -e .[cpu] - pip install mkdocs mkdocs-material mkdocs-rtd-dropdown mkdocs-jupyter - - name: build docs - working-directory: docs - run: | - mkdocs build + source .venv/bin/activate + uv pip install -e ".[cpu,docs]" + - name: Generate notebooks working-directory: docs run: | + source ../.venv/bin/activate export JAX_PLATFORMS=cpu for FILE in $(find docs -name '*.py'); do export OUTFILE="${FILE%.*}.ipynb" - jupytext --to ipynb $FILE -o "$OUTFILE" - jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) + uv tool run jupytext --to ipynb $FILE -o "$OUTFILE" + uv tool run jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) done - - name: Build site - working-directory: docs + + - name: Build documentation run: | - mkdocs build + source .venv/bin/activate + make docs + - name: Upload artifact uses: actions/upload-pages-artifact@v3 with: - # Upload entire repository + # Upload docs site path: "./docs/site" + - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 0e3f106..5f71d24 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,21 +1,28 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# This workflow will install Python dependencies and run tests with multiple Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Python package +name: Python unit tests on: push: branches: ["main"] paths-ignore: - "docs/**" + - "**.md" pull_request: branches: ["main"] paths-ignore: - "docs/**" + - "**.md" + +# Cancel in-progress runs when a new commit is pushed +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: - build: - name: Python unit tests + test: + name: Python unit tests (${{ matrix.python-version }}) runs-on: ubuntu-latest strategy: fail-fast: false @@ -25,39 +32,41 @@ jobs: steps: - uses: actions/checkout@v5 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Create virtual environment + run: uv venv --python ${{ matrix.python-version }} + - name: Install dependencies run: | - pip install --upgrade pip - # note: installing torch and torchvision together ensures their version compatibility - pip install torch==2.8.0 torchvision --index-url https://download.pytorch.org/whl/cpu - pip install jupyter nbconvert pandas matplotlib - pip install -r test-requirements.txt - pip install -e .[cpu] - - name: Test with pytest - shell: bash + source .venv/bin/activate + make install-test + + - name: Run unit tests + run: | + source .venv/bin/activate + make test + continue-on-error: false + + - name: Run distributed tests run: | - export JAX_PLATFORMS=cpu - # Find all Python test files recursively - find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do - # Skip tests with known issues - if [[ "$test_file" == *"test_tf_integration.py"* ]]; then - echo "Skipping ${test_file}. TODO(https://github.com/pytorch/xla/issues/8770): Investigate" - continue - fi - echo "Running tests for $test_file" - pytest "$test_file" - done - # Run distributed tests. - XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/ - echo "Tests completed." + source .venv/bin/activate + XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest test_dist/ -n 0 + - name: Test tutorials can run - shell: bash run: | + source .venv/bin/activate export JAX_PLATFORMS=cpu - for FILE in $(find docs -name '*.py'); do - python $FILE + for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do + if [ -f "$FILE" ]; then + echo "Testing $FILE" + python "$FILE" + fi done diff --git a/.gitignore b/.gitignore index a6fc076..5fac29d 100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,7 @@ activemq-data/ .env .envrc .venv +.local env/ venv/ ENV/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..4feb2ff --- /dev/null +++ b/.python-version @@ -0,0 +1,2 @@ +3.11 + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 356bf35..f450ec4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,38 +1,229 @@ # Contributing -We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from ``good first issue`` and ``help wanted`` labels. +We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from `good first issue` and `help wanted` labels. If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. ## Developer Setup -### Mac Setup +### Prerequisites -You can develop directly on a Mac (M1) for most parts. Using the steps in the README works. Here is a condensed version for easy copy & paste: +- Python 3.11 or higher +- Git + +### Quick Start (Recommended) + +The Makefile handles everything, including uv installation: ```bash - conda create --name python=3.11 - conda activate - pip install --upgrade "jax[cpu]" torch - pip install -r test-requirements.txt - pip install -e . - pytest test +# Clone the repository +git clone https://github.com/google/torchax.git +cd torchax + +# Install uv locally (if not already in system) +make install-uv + +# Create a virtual environment +uv venv --python 3.11 +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install with all development dependencies +make install + +# Verify setup +make check-env + +# Run tests +make test ``` -### Ruff +### Alternative: System-wide uv + +If you prefer uv in your PATH: + +```bash +# Install uv system-wide +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Then follow the same steps above +uv venv --python 3.11 +source .venv/bin/activate +make install ``` -ruff check torchax test test_dist examples --fix -ruff format torchax test test_dist examples + +### Without uv (Traditional) + +```bash +# Create a virtual environment +python3.11 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install the package in editable mode with dependencies +pip install -e ".[cpu,dev,docs]" ``` -### VSCode +### Mac Setup (M1/M2/M3) + +Development works great on Apple Silicon Macs: + +```bash +# Recommended: Using uv +make install-uv # Install uv if needed +uv venv --python 3.11 +source .venv/bin/activate +make install + +# Alternative: Using conda +conda create --name torchax python=3.11 +conda activate torchax +pip install -e ".[cpu,dev,docs]" +``` -It is recommended to use VSCode on Mac. You can follow the instructions in the `VSCode Python tutorial `_ to set up a proper Python environment. +### Hardware-Specific Installation + +```bash +# CPU (default - flexible versions for development) +make install + +# CPU with pinned test versions (exactly like CI) +make install-test + +# CUDA +make install-cuda + +# TPU (requires additional setup) +make install-tpu +``` + +## Development Workflow + +### Using Make Commands + +We provide a Makefile for common development tasks: + +```bash +# See all available commands +make help + +# Install for development +make dev + +# Check for issues +make lint + +# Auto-fix issues and format +make format + +# Format code +make format + +# Run tests +make test + +# Run all tests (including distributed) +make test-all + +# Clean build artifacts +make clean +``` -The recommended plugins are: +## Testing + +```bash +# Run unit tests +make test + +# Run specific test file +JAX_PLATFORMS=cpu pytest test/test_ops.py + +# Run tests with verbose output +JAX_PLATFORMS=cpu pytest test/ -v + +# Run distributed tests +make test-all +``` + +## Project Structure + +``` +torchax/ +├── torchax/ # Main package +│ ├── ops/ # Operator implementations +│ ├── tensor.py # Core tensor functionality +│ └── ... +├── test/ # Unit tests +├── test_dist/ # Distributed tests +├── examples/ # Example scripts +├── docs/ # Documentation +├── pyproject.toml # Project configuration +└── Makefile # Development commands +``` + +## VSCode Setup + +Recommended extensions: + +- **Python** (ms-python.python) +- **Ruff** (charliermarsh.ruff) +- **Python Debugger** (ms-python.debugpy) + +### Settings + +Add to `.vscode/settings.json`: + +```json +{ + "python.defaultInterpreterPath": ".venv/bin/python", + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + } + }, + "ruff.lint.args": ["--config=pyproject.toml"], + "ruff.format.args": ["--config=pyproject.toml"] +} +``` + +## Common Issues + +### ImportError after code changes + +If you get import errors or operators not found: + +```bash +# Reinstall the package +pip install -e . +``` + +### JAX backend issues + +Set the JAX platform explicitly: + +```bash +export JAX_PLATFORMS=cpu # or cuda, tpu +``` + +### If tests fail locally but passing in CI + +Make sure you have the latest dependencies: + +```bash +make clean +make install-test # Use exact CI versions +``` + +## Documentation + +Build and serve documentation locally: + +```bash +make docs-serve +``` -* VSCode's official Python plugin -* Ruff formatter -* Python Debugger +## Getting Help -You should also change the Python interpreter to point at the one in your conda environment. +- **Issues**: [GitHub Issues](https://github.com/google/torchax/issues) +- **Discussions**: [GitHub Discussions](https://github.com/google/torchax/discussions) diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..f3400b6 --- /dev/null +++ b/Makefile @@ -0,0 +1,184 @@ +# TorchAx Development Makefile + +# Colors for output +ORANGE := \033[0;33m +GREEN := \033[0;32m +RESET := \033[0m + +# UV binary - use system uv if available, otherwise use local +UV := $(shell command -v uv 2>/dev/null || echo "./.local/bin/uv") + +# Prefer managed Python from uv +export UV_PYTHON_PREFERENCE := only-managed + +.PHONY: help install install-test lint lint-check format test test-fast test-all clean clean-all build docs + +# Default target +help: + @echo "$(ORANGE)TorchAx Development Commands$(RESET)" + @echo "" + @echo "Setup & Installation:" + @echo " make install-uv Install uv locally (if not in system)" + @echo " make install Install for development (flexible versions)" + @echo " make install-test Install with pinned test versions (like CI)" + @echo " make install-cuda Install with CUDA backend" + @echo " make install-tpu Install with TPU backend" + @echo "" + @echo "Development:" + @echo " make lint Auto-fix and format code" + @echo " make lint-check Check code without modifying" + @echo " make format Format code only" + @echo " make test Run unit tests (file-by-file like CI)" + @echo " make test-fast Run unit tests (parallel)" + @echo " make test-all Run all tests" + @echo "" + @echo "Cleaning:" + @echo " make clean Clean build artifacts" + @echo " make clean-all Deep clean" + @echo "" + @echo "Building:" + @echo " make build Build package" + @echo " make docs Build documentation" + +# === Ensure uv is available === + +.local/bin/uv: + @echo "$(ORANGE)Installing uv locally to .local/bin/...$(RESET)" + @mkdir -p .local/bin + @curl -LsSf https://astral.sh/uv/install.sh | sh + @if [ -f ~/.local/bin/uv ]; then \ + cp ~/.local/bin/uv .local/bin/uv; \ + echo "$(GREEN)✓ uv installed to .local/bin/uv$(RESET)"; \ + elif [ -f ~/.cargo/bin/uv ]; then \ + cp ~/.cargo/bin/uv .local/bin/uv; \ + echo "$(GREEN)✓ uv installed to .local/bin/uv$(RESET)"; \ + else \ + echo "Error: uv installation failed - check if it's in your PATH"; \ + exit 1; \ + fi + +install-uv: .local/bin/uv + @echo "$(GREEN)✓ uv is ready at .local/bin/uv$(RESET)" + +# === Installation === + +install: + @echo "$(ORANGE)Installing for development (flexible versions)...$(RESET)" + @$(UV) pip install -e ".[cpu,dev,docs]" + @echo "$(GREEN)✓ Installation complete!$(RESET)" + +install-test: + @echo "$(ORANGE)Installing with pinned test versions (like CI)...$(RESET)" + @$(UV) pip install -e ".[test,dev,docs]" --extra-index-url https://download.pytorch.org/whl/cpu + @echo "$(GREEN)✓ Installation complete!$(RESET)" + +install-cuda: + @$(UV) pip install -e ".[cuda,dev,test]" + +install-tpu: + @$(UV) pip install -e ".[tpu,dev,test]" + +# === Linting & Formatting === + +lint: + @echo "$(ORANGE)1. ==== Ruff format ====$(RESET)" + @$(UV) tool run ruff format torchax test test_dist + @echo "$(ORANGE)2. ==== Ruff check & fix ====$(RESET)" + @$(UV) tool run ruff check torchax test test_dist --fix + @echo "$(GREEN)✓ Code formatted and linted!$(RESET)" + +lint-check: + @echo "$(ORANGE)1. ==== Ruff format check ====$(RESET)" + @$(UV) tool run ruff format --check torchax test test_dist + @echo "$(ORANGE)2. ==== Ruff check ====$(RESET)" + @$(UV) tool run ruff check torchax test test_dist + @echo "$(GREEN)✓ Linting passed!$(RESET)" + +format: + @echo "$(ORANGE)Formatting code with ruff...$(RESET)" + @$(UV) tool run ruff format torchax test test_dist + @echo "$(GREEN)✓ Code formatted!$(RESET)" + +# === Testing === + +test: + @echo "$(ORANGE)Running unit tests...$(RESET)" + @export JAX_PLATFORMS=cpu && \ + find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do \ + echo "Running tests for $$test_file"; \ + pytest "$$test_file" || exit 1; \ + done + @echo "$(GREEN)✓ Unit tests completed!$(RESET)" + +test-fast: + @echo "$(ORANGE)Running unit tests (parallel)...$(RESET)" + @JAX_PLATFORMS=cpu pytest test/ -v --tb=short -n auto + @echo "$(GREEN)✓ Tests completed!$(RESET)" + +test-all: + @echo "$(ORANGE)Running all tests...$(RESET)" + @$(MAKE) test + @echo "" + @echo "$(ORANGE)Running distributed tests...$(RESET)" + @XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest test_dist/ -n 0 + @echo "" + @echo "$(ORANGE)Running tutorial tests...$(RESET)" + @export JAX_PLATFORMS=cpu && \ + for file in $$(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do \ + if [ -f "$$file" ]; then \ + echo "Testing $$file"; \ + python "$$file" || exit 1; \ + fi \ + done + @echo "$(GREEN)✓ All tests completed!$(RESET)" + +test-coverage: + @echo "$(ORANGE)Running tests with coverage...$(RESET)" + @JAX_PLATFORMS=cpu pytest test/ --cov=torchax --cov-report=html --cov-report=term + +# === Cleaning === + +clean: + @echo "$(ORANGE)Cleaning build artifacts...$(RESET)" + @rm -rf build/ dist/ *.egg-info + @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true + @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true + @echo "$(GREEN)✓ Clean complete!$(RESET)" + +clean-all: clean + @echo "$(ORANGE)Deep cleaning...$(RESET)" + @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + @find . -type f -name "*.pyc" -delete + @find . -type f -name "*.pyo" -delete + @rm -rf .coverage htmlcov/ + @echo "$(GREEN)✓ Deep clean complete!$(RESET)" + +# === Building === + +build: clean + @echo "$(ORANGE)Building package...$(RESET)" + @$(UV) build + @echo "$(GREEN)✓ Build complete!$(RESET)" + +docs: + @echo "$(ORANGE)Building documentation...$(RESET)" + @cd docs && $(UV) run mkdocs build + @echo "$(GREEN)✓ Documentation built!$(RESET)" + +docs-serve: + @echo "$(ORANGE)Serving documentation at http://127.0.0.1:8000$(RESET)" + @cd docs && $(UV) run mkdocs serve + +# === CI Simulation === + +ci: lint-check test + @echo "$(GREEN)✓ CI checks passed!$(RESET)" + +# === Utilities === + +check-env: + @echo "Python: $$(python --version)" + @echo "uv: $$($(UV) --version 2>/dev/null || echo 'not available')" + @python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null || echo "PyTorch: not installed" + @python -c "import jax; print(f'JAX: {jax.__version__}')" 2>/dev/null || echo "JAX: not installed" + @python -c "import torchax; print(f'TorchAx: {torchax.__version__}')" 2>/dev/null || echo "TorchAx: not installed" diff --git a/dev-requirements.txt b/dev-requirements.txt index e75e601..1b7609c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,19 @@ --f https://download.pytorch.org/whl/torch -torch==2.8.0 ; sys_platform == 'darwin' # macOS -torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU -yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` -ruff>=0.9.0 # Linter and formatter -flax==0.10.6 -jax==0.7.2 +# Development Requirements (Legacy) + +**Note**: This file is kept for backward compatibility. +New development should use: +```bash +make install +# or +uv pip install -e ".[cpu,dev,docs]" +``` + +See `pyproject.toml` for the canonical list of dependencies organized by group: +- `[project.optional-dependencies.dev]` - Development tools (flexible versions) +- `[project.optional-dependencies.test]` - Testing dependencies (pinned versions) +- `[project.optional-dependencies.docs]` - Documentation tools + +For hardware-specific dependencies, see: +- `[project.optional-dependencies.cpu]` - CPU backend +- `[project.optional-dependencies.cuda]` - CUDA backend +- `[project.optional-dependencies.tpu]` - TPU backend diff --git a/pyproject.toml b/pyproject.toml index aa8de25..a0cdfd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,10 @@ build-backend = "hatchling.build" [project] name = "torchax" -dependencies = [] +dependencies = [ + "jax[cpu]>=0.6.2", # Flexible - works with newer versions + "torch>=2.4.0", # Flexible - works with newer versions +] requires-python = ">=3.11" license = {file = "LICENSE"} dynamic = ["version"] @@ -47,17 +50,70 @@ classifiers = [ [project.urls] "Homepage" = "https://github.com/google/torchax" - +"Repository" = "https://github.com/google/torchax" +"Bug Tracker" = "https://github.com/google/torchax/issues" [tool.hatch.version] path = "torchax/__init__.py" [project.optional-dependencies] -cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"] -# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html` -tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"] -cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"] -odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] +# Hardware-specific backends (flexible versions) +cpu = ["jax[cpu]>=0.6.2"] +tpu = ["jax[tpu]>=0.6.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html +cuda = ["jax[cuda12]>=0.6.2"] +odml = ["jax[cpu]>=0.6.2"] + +# Development dependencies (flexible versions for latest features) +dev = [ + "ruff>=0.9.0", + "yapf>=0.40.0", + "flax>=0.10.0", + "torchvision>=0.19.0", # For tutorials/examples +] + +# Test dependencies (PINNED versions for reproducible CI) +test = [ + # Pinned PyTorch - platform specific + "torch==2.8.0 ; sys_platform == 'darwin'", # macOS from PyPI + "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index + # Other pinned versions + "jax[cpu]==0.7.2", + "flax==0.10.6", + # Test framework + "pytest>=8.3.5", + "pytest-xdist>=3.8.0", + "pytest-reraise>=2.1.2", + "setuptools>=70.0.0", # Required by torch.testing._internal + "absl-py>=2.2.2", + "immutabledict>=4.2.1", + "sentencepiece>=0.2.0", + "expecttest>=0.3.0", + "optax>=0.2.4", + "termcolor>=2.0.0", + "tensorboard>=2.15.0", +] + +# Documentation dependencies +docs = [ + "mkdocs>=1.5.0", + "mkdocs-material>=9.0.0", + "jupyter>=1.0.0", + "nbconvert>=7.0.0", + "pandas>=2.0.0", + "matplotlib>=3.7.0", +] + +# Training/experimentation dependencies +train = [ + "tensorboard>=2.15.0", + "optax>=0.2.4", +] + +# All dependencies for complete development setup +# Note: Install torch separately due to platform-specific builds +all = [ + "torchax[dev,test,docs,train]", +] [tool.hatch.build.targets.wheel] packages = ["torchax"] diff --git a/test-requirements.txt b/test-requirements.txt index fc6f5f5..b85ca6b 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,16 +1,11 @@ --r dev-requirements.txt -absl-py==2.2.2 -immutabledict==4.2.1 -pytest==8.3.5 -sentencepiece -expecttest==0.3.0 -optax==0.2.4 -pytest -pytest-xdist -termcolor -jupyter -nbconvert -pandas -matplotlib -tensorboard -pytest-reraise +# Test Requirements (Legacy) + +**Note**: This file is kept for backward compatibility. +New development should use: +```bash +make install-test +# or +pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu +``` + +See `pyproject.toml` `[project.optional-dependencies.test]` for the canonical list of test dependencies.