From 42d661775683615914fae838f5d5bd3ed9b7bb8a Mon Sep 17 00:00:00 2001 From: shitong Date: Mon, 17 Nov 2025 16:16:11 -0800 Subject: [PATCH 01/16] add make file --- .github/workflows/lint.yml | 17 +-- .github/workflows/nightly.yml | 17 ++- .github/workflows/release.yml | 19 ++- .github/workflows/static.yml | 37 +++--- .github/workflows/unit-tests.yml | 60 +++++---- .python-version | 2 + CONTRIBUTING.md | 208 ++++++++++++++++++++++++++++--- Makefile | 147 ++++++++++++++++++++++ dev-requirements.txt | 26 ++-- pyproject.toml | 59 +++++++-- test-requirements.txt | 27 ++-- 11 files changed, 488 insertions(+), 131 deletions(-) create mode 100644 .python-version create mode 100644 Makefile diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cfb057b..281cc1f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,22 +22,15 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 with: - version: "0.9.5" + enable-cache: true + cache-dependency-glob: "pyproject.toml" - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.11" + run: uv python install 3.11 - name: Install dependencies run: | uv pip install --system ruff - - name: Run ruff check - run: | - ruff check torchax test test_dist - - - name: Run ruff format check - run: | - ruff format --check torchax test test_dist - + - name: Run linter + run: make lint diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7bf4fe2..9ab1f2e 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -19,27 +19,26 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: '3.11' + enable-cache: true - - name: Install build dependencies - run: pip install build + - name: Set up Python + run: uv python install 3.11 - name: Patch version for Nightly run: | python 
scripts/update_to_nightly_version.py - # Verify the change (Debugging) grep "__version__" torchax/__init__.py - - name: Build Wheel and Sdist - run: python -m build --wheel + - name: Build package + run: make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - # Note: We are NOT providing a password here because we will use Trusted Publishing. + # Note: Using Trusted Publishing (passwordless) # If you must use a token, uncomment the lines below: # with: # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7ba0ee9..ceddc39 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: # in PyPI Settings > Publishing for your release configuration. environment: name: pypi - url: https://pypi.org/p/torchax # Replace 'torchax' with your package name + url: https://pypi.org/p/torchax permissions: # MANDATORY for OIDC Trusted Publishing @@ -30,18 +30,17 @@ jobs: with: fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: '3.11' + enable-cache: true - - name: Install build dependencies - run: pip install build + - name: Set up Python + run: uv python install 3.11 - - name: Build Wheel - # We build both here, as stable releases usually offer both formats. - run: python -m build --wheel + - name: Build package + run: make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - # Trusted Publishing handles authentication automatically. 
+ # Trusted Publishing handles authentication automatically diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index 66fb1ba..4f4d44e 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -31,25 +31,23 @@ jobs: steps: - name: Checkout uses: actions/checkout@v5 - - name: Set up Python 3.12 - uses: actions/setup-python@v6 + + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: 3.12 + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python + run: uv python install 3.12 + - name: Setup Pages uses: actions/configure-pages@v5 + - name: Install dependencies run: | - pip install --upgrade pip - # note: installing torch and torchvision together ensures their version compatibility - pip install torch==2.8.0 torchvision --index-url https://download.pytorch.org/whl/cpu - pip install jupyter nbconvert pandas matplotlib - pip install -r test-requirements.txt - pip install -e .[cpu] - pip install mkdocs mkdocs-material mkdocs-rtd-dropdown mkdocs-jupyter - - name: build docs - working-directory: docs - run: | - mkdocs build + uv pip install --system -e ".[cpu,docs]" + - name: Generate notebooks working-directory: docs run: | @@ -59,15 +57,16 @@ jobs: jupytext --to ipynb $FILE -o "$OUTFILE" jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) done - - name: Build site - working-directory: docs - run: | - mkdocs build + + - name: Build documentation + run: make docs + - name: Upload artifact uses: actions/upload-pages-artifact@v3 with: - # Upload entire repository + # Upload docs site path: "./docs/site" + - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 0e3f106..dfdc0db 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,21 +1,23 @@ -# This workflow will install Python dependencies, run tests and lint with 
a variety of Python versions +# This workflow will install Python dependencies and run tests with multiple Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Python package +name: Tests on: push: branches: ["main"] paths-ignore: - "docs/**" + - "**.md" pull_request: branches: ["main"] paths-ignore: - "docs/**" + - "**.md" jobs: - build: - name: Python unit tests + test: + name: Python ${{ matrix.python-version }} Tests runs-on: ubuntu-latest strategy: fail-fast: false @@ -25,39 +27,33 @@ jobs: steps: - uses: actions/checkout@v5 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} + enable-cache: true + cache-dependency-glob: "pyproject.toml" + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + - name: Install dependencies run: | - pip install --upgrade pip - # note: installing torch and torchvision together ensures their version compatibility - pip install torch==2.8.0 torchvision --index-url https://download.pytorch.org/whl/cpu - pip install jupyter nbconvert pandas matplotlib - pip install -r test-requirements.txt - pip install -e .[cpu] - - name: Test with pytest - shell: bash + uv pip install --system -e ".[cpu,test,docs]" + + - name: Run unit tests + run: make test + continue-on-error: false + + - name: Run distributed tests run: | - export JAX_PLATFORMS=cpu - # Find all Python test files recursively - find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do - # Skip tests with known issues - if [[ "$test_file" == *"test_tf_integration.py"* ]]; then - echo "Skipping ${test_file}. TODO(https://github.com/pytorch/xla/issues/8770): Investigate" - continue - fi - echo "Running tests for $test_file" - pytest "$test_file" - done - # Run distributed tests. 
- XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/ - echo "Tests completed." + XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run pytest test_dist/ -n 0 + - name: Test tutorials can run - shell: bash run: | export JAX_PLATFORMS=cpu - for FILE in $(find docs -name '*.py'); do - python $FILE + for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do + if [ -f "$FILE" ]; then + echo "Testing $FILE" + python "$FILE" + fi done diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..4feb2ff --- /dev/null +++ b/.python-version @@ -0,0 +1,2 @@ +3.11 + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 356bf35..e143cd1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,38 +1,208 @@ # Contributing -We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from ``good first issue`` and ``help wanted`` labels. +We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from `good first issue` and `help wanted` labels. If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. ## Developer Setup -### Mac Setup +### Prerequisites -You can develop directly on a Mac (M1) for most parts. Using the steps in the README works. 
Here is a condensed version for easy copy & paste: +- Python 3.11 or higher +- [uv](https://docs.astral.sh/uv/) (recommended) or pip + +### Quick Start with uv (Recommended) + +```bash +# Install uv if you haven't already +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Clone the repository +git clone https://github.com/google/torchax.git +cd torchax + +# Create a virtual environment (optional but recommended) +uv venv --python 3.11 +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install with all development dependencies +make dev +# Or manually: +# uv pip install -e ".[cpu,dev,test,docs]" + +# Run tests to verify setup +make test +``` + +### Alternative Setup (without uv) + +```bash +# Create a virtual environment +python3.11 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install the package in editable mode with dependencies +pip install -e ".[cpu,dev,test]" +``` + +### Mac Setup (M1/M2/M3) + +Development works great on Apple Silicon Macs: + +```bash +# Option 1: Using uv (recommended) +uv venv --python 3.11 +source .venv/bin/activate +make dev + +# Option 2: Using conda +conda create --name torchax python=3.11 +conda activate torchax +pip install -e ".[cpu,dev,test]" +``` + +### Hardware-Specific Installation + +```bash +# CPU (default) +uv pip install -e ".[cpu,dev,test]" + +# CUDA +uv pip install -e ".[cuda,dev,test]" + +# TPU (requires additional setup) +uv pip install -e ".[tpu,dev,test]" +``` + +## Development Workflow + +### Using Make Commands + +We provide a Makefile for common development tasks: + +```bash +# See all available commands +make help + +# Install for development +make dev + +# Check for issues +make lint + +# Auto-fix issues and format +make format + +# Format code +make format + +# Run tests +make test + +# Run all tests (including distributed) +make test-all + +# Clean build artifacts +make clean +``` + +## Testing + +```bash +# Run unit tests +make test + +# Run specific test 
file +JAX_PLATFORMS=cpu pytest test/test_ops.py + +# Run tests with verbose output +JAX_PLATFORMS=cpu pytest test/ -v + +# Run distributed tests +make test-all +``` + +## Project Structure + +``` +torchax/ +├── torchax/ # Main package +│ ├── ops/ # Operator implementations +│ ├── tensor.py # Core tensor functionality +│ └── ... +├── test/ # Unit tests +├── test_dist/ # Distributed tests +├── examples/ # Example scripts +├── docs/ # Documentation +├── pyproject.toml # Project configuration +└── Makefile # Development commands +``` + +## VSCode Setup + +Recommended extensions: + +- **Python** (ms-python.python) +- **Ruff** (charliermarsh.ruff) +- **Python Debugger** (ms-python.debugpy) + +### Settings + +Add to `.vscode/settings.json`: + +```json +{ + "python.defaultInterpreterPath": ".venv/bin/python", + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + } + }, + "ruff.lint.args": ["--config=pyproject.toml"], + "ruff.format.args": ["--config=pyproject.toml"] +} +``` + +## Common Issues + +### ImportError after code changes + +If you get import errors or operators not found: ```bash - conda create --name python=3.11 - conda activate - pip install --upgrade "jax[cpu]" torch - pip install -r test-requirements.txt - pip install -e . - pytest test +# Reinstall the package +pip install -e . ``` -### Ruff +### JAX backend issues + +Set the JAX platform explicitly: + +```bash +export JAX_PLATFORMS=cpu # or cuda, tpu ``` -ruff check torchax test test_dist examples --fix -ruff format torchax test test_dist examples + +### Tests failing locally but passing in CI + +Make sure you have the latest dependencies: + +```bash +make clean +make dev ``` -### VSCode +## Documentation -It is recommended to use VSCode on Mac. You can follow the instructions in the `VSCode Python tutorial `_ to set up a proper Python environment. 
+Build and serve documentation locally: -The recommended plugins are: +```bash +make docs-serve +``` -* VSCode's official Python plugin -* Ruff formatter -* Python Debugger +## Getting Help -You should also change the Python interpreter to point at the one in your conda environment. +- **Issues**: [GitHub Issues](https://github.com/google/torchax/issues) +- **Discussions**: [GitHub Discussions](https://github.com/google/torchax/discussions) diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..515b2b1 --- /dev/null +++ b/Makefile @@ -0,0 +1,147 @@ +# TorchAx Development Makefile +# +# This Makefile provides convenient commands for development. +# All commands use uv for fast, reliable dependency management. + +.PHONY: help install dev test test-all lint format clean build docs + +# Default target +help: + @echo "TorchAx Development Commands" + @echo "" + @echo "Setup & Installation:" + @echo " make install Install package with CPU backend" + @echo " make dev Install package with all dev dependencies" + @echo " make install-cuda Install with CUDA backend" + @echo " make install-tpu Install with TPU backend" + @echo "" + @echo "Development:" + @echo " make lint Run linters (ruff check)" + @echo " make format Format code with ruff" + @echo " make test Run unit tests (file-by-file like CI)" + @echo " make test-fast Run unit tests (parallel, faster)" + @echo " make test-all Run all tests (unit + distributed + tutorials)" + @echo " make test-gemma Run gemma tests" + @echo "" + @echo "Cleaning:" + @echo " make clean Clean build artifacts and caches" + @echo " make clean-all Deep clean including Python caches" + @echo "" + @echo "Building:" + @echo " make build Build package" + @echo " make docs Build documentation" + +# === Installation === + +install: + uv pip install -e ".[cpu]" + +dev: + uv pip install -e ".[cpu,dev,test,docs]" + +install-cuda: + uv pip install -e ".[cuda,dev,test]" + +install-tpu: + uv pip install -e ".[tpu,dev,test]" + +# === Linting & 
Formatting === + +lint: + @echo "Running ruff check..." + @uv run ruff check torchax test test_dist + @echo "✓ Linting passed!" + +format: + @echo "Formatting code with ruff..." + @uv run ruff check torchax test test_dist --fix + @uv run ruff format torchax test test_dist + @echo "✓ Code formatted!" + +# === Testing === + +test: + @echo "Running unit tests..." + @export JAX_PLATFORMS=cpu && \ + find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do \ + echo "Running tests in $$test_file"; \ + uv run pytest "$$test_file" -v --tb=short || exit 1; \ + done + @echo "✓ Unit tests completed!" + +test-fast: + @echo "Running unit tests (parallel)..." + @JAX_PLATFORMS=cpu uv run pytest test/ -v --tb=short -n auto + +test-all: + @echo "Running unit tests..." + @$(MAKE) test + @echo "" + @echo "Running distributed tests..." + @XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run pytest test_dist/ -n 0 + @echo "" + @echo "Running tutorial tests..." + @export JAX_PLATFORMS=cpu && \ + for file in $$(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do \ + if [ -f "$$file" ]; then \ + echo "Testing $$file"; \ + python "$$file" || exit 1; \ + fi \ + done + @echo "✓ All tests completed!" + +test-gemma: + @echo "Running gemma tests..." + @JAX_PLATFORMS=cpu uv run pytest test/gemma/test_gemma.py -v + +test-coverage: + @echo "Running tests with coverage..." + @JAX_PLATFORMS=cpu uv run pytest test/ --cov=torchax --cov-report=html --cov-report=term + +# === Cleaning === + +clean: + @echo "Cleaning build artifacts..." + @rm -rf build/ dist/ *.egg-info + @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true + @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true + @echo "✓ Clean complete!" + +clean-all: clean + @echo "Deep cleaning..." + @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + @find . -type f -name "*.pyc" -delete + @find . 
-type f -name "*.pyo" -delete + @rm -rf .coverage htmlcov/ + @echo "✓ Deep clean complete!" + +# === Building === + +build: clean + @echo "Building package..." + @uv build + @echo "✓ Build complete!" + +docs: + @echo "Building documentation..." + @cd docs && uv run mkdocs build + @echo "✓ Documentation built! Open docs/site/index.html" + +docs-serve: + @echo "Serving documentation at http://127.0.0.1:8000" + @cd docs && uv run mkdocs serve + +# === CI Simulation === + +ci: lint test + @echo "✓ CI checks passed!" + +# === Utilities === + +check-env: + @echo "Python: $$(python --version)" + @echo "uv: $$(uv --version)" + @python -c "import torch; print(f'PyTorch: {torch.__version__}')" + @python -c "import jax; print(f'JAX: {jax.__version__}')" + @python -c "import torchax; print(f'TorchAx: {torchax.__version__}')" + diff --git a/dev-requirements.txt b/dev-requirements.txt index e75e601..bcb026e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,19 @@ --f https://download.pytorch.org/whl/torch -torch==2.8.0 ; sys_platform == 'darwin' # macOS -torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU -yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` -ruff>=0.9.0 # Linter and formatter -flax==0.10.6 -jax==0.7.2 +# Development Requirements (Legacy) + +**Note**: This file is kept for backward compatibility. 
+New development should use: +```bash +make dev +# or +uv pip install -e ".[cpu,dev,test,docs]" +``` + +See `pyproject.toml` for the canonical list of dependencies organized by group: +- `[project.optional-dependencies.dev]` - Development tools +- `[project.optional-dependencies.test]` - Testing dependencies +- `[project.optional-dependencies.docs]` - Documentation tools + +For hardware-specific dependencies, see: +- `[project.optional-dependencies.cpu]` - CPU backend +- `[project.optional-dependencies.cuda]` - CUDA backend +- `[project.optional-dependencies.tpu]` - TPU backend diff --git a/pyproject.toml b/pyproject.toml index aa8de25..ad68f4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,10 @@ build-backend = "hatchling.build" [project] name = "torchax" -dependencies = [] +dependencies = [ + "jax[cpu]>=0.6.2", + "torch>=2.8.0", +] requires-python = ">=3.11" license = {file = "LICENSE"} dynamic = ["version"] @@ -47,17 +50,59 @@ classifiers = [ [project.urls] "Homepage" = "https://github.com/google/torchax" - +"Repository" = "https://github.com/google/torchax" +"Bug Tracker" = "https://github.com/google/torchax/issues" [tool.hatch.version] path = "torchax/__init__.py" [project.optional-dependencies] -cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"] -# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html` -tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"] -cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"] -odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] +# Hardware-specific backends +cpu = ["jax[cpu]>=0.6.2"] +tpu = ["jax[tpu]>=0.6.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html +cuda = ["jax[cuda12]>=0.6.2"] +odml = ["jax[cpu]>=0.6.2"] + +# Development dependencies +dev = [ + "ruff>=0.9.0", + "yapf==0.40.2", + "flax==0.10.6", +] + +# Test dependencies +test = [ + "pytest>=8.3.5", + "pytest-xdist>=3.8.0", + "pytest-reraise>=2.1.2", + "absl-py>=2.2.2", + "immutabledict>=4.2.1", + 
"sentencepiece>=0.2.0", + "expecttest>=0.3.0", + "optax>=0.2.4", + "termcolor>=2.0.0", +] + +# Documentation dependencies +docs = [ + "mkdocs>=1.5.0", + "mkdocs-material>=9.0.0", + "jupyter>=1.0.0", + "nbconvert>=7.0.0", + "pandas>=2.0.0", + "matplotlib>=3.7.0", +] + +# Training/experimentation dependencies +train = [ + "tensorboard>=2.15.0", + "optax>=0.2.4", +] + +# All dependencies for complete development setup +all = [ + "torchax[dev,test,docs,train]", +] [tool.hatch.build.targets.wheel] packages = ["torchax"] diff --git a/test-requirements.txt b/test-requirements.txt index fc6f5f5..85117f0 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,16 +1,11 @@ --r dev-requirements.txt -absl-py==2.2.2 -immutabledict==4.2.1 -pytest==8.3.5 -sentencepiece -expecttest==0.3.0 -optax==0.2.4 -pytest -pytest-xdist -termcolor -jupyter -nbconvert -pandas -matplotlib -tensorboard -pytest-reraise +# Test Requirements (Legacy) + +**Note**: This file is kept for backward compatibility. +New development should use: +```bash +make dev +# or +uv pip install -e ".[cpu,dev,test]" +``` + +See `pyproject.toml` `[project.optional-dependencies.test]` for the canonical list of test dependencies. 
From 744ab4e07bea7761350c35c2abfa7d1c72d7cd01 Mon Sep 17 00:00:00 2001 From: shitong Date: Mon, 17 Nov 2025 16:23:21 -0800 Subject: [PATCH 02/16] fix --- .github/workflows/lint.yml | 6 ++++-- .github/workflows/nightly.yml | 5 ++++- .github/workflows/release.yml | 3 +++ .github/workflows/static.yml | 9 ++++++--- .github/workflows/unit-tests.yml | 9 ++++++--- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 281cc1f..04f3732 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -28,9 +28,11 @@ jobs: - name: Set up Python run: uv python install 3.11 + - name: Create virtual environment + run: uv venv + - name: Install dependencies - run: | - uv pip install --system ruff + run: uv pip install ruff - name: Run linter run: make lint diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 9ab1f2e..ae77d78 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -27,9 +27,12 @@ jobs: - name: Set up Python run: uv python install 3.11 + - name: Create virtual environment + run: uv venv + - name: Patch version for Nightly run: | - python scripts/update_to_nightly_version.py + uv run python scripts/update_to_nightly_version.py # Verify the change (Debugging) grep "__version__" torchax/__init__.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ceddc39..927ef1e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,6 +38,9 @@ jobs: - name: Set up Python run: uv python install 3.11 + - name: Create virtual environment + run: uv venv + - name: Build package run: make build diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index 4f4d44e..9bee94b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -41,12 +41,15 @@ jobs: - name: Set up Python run: uv python install 3.12 + - name: Create virtual environment + run: uv venv + - 
name: Setup Pages uses: actions/configure-pages@v5 - name: Install dependencies run: | - uv pip install --system -e ".[cpu,docs]" + uv pip install -e ".[cpu,docs]" - name: Generate notebooks working-directory: docs @@ -54,8 +57,8 @@ jobs: export JAX_PLATFORMS=cpu for FILE in $(find docs -name '*.py'); do export OUTFILE="${FILE%.*}.ipynb" - jupytext --to ipynb $FILE -o "$OUTFILE" - jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) + uv run jupytext --to ipynb $FILE -o "$OUTFILE" + uv run jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) done - name: Build documentation diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index dfdc0db..767a342 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -17,7 +17,7 @@ on: jobs: test: - name: Python ${{ matrix.python-version }} Tests + name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -36,9 +36,12 @@ jobs: - name: Set up Python ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }} + - name: Create virtual environment + run: uv venv --python ${{ matrix.python-version }} + - name: Install dependencies run: | - uv pip install --system -e ".[cpu,test,docs]" + uv pip install -e ".[cpu,test,docs]" - name: Run unit tests run: make test @@ -54,6 +57,6 @@ jobs: for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do if [ -f "$FILE" ]; then echo "Testing $FILE" - python "$FILE" + uv run python "$FILE" fi done From f94992a5124e945c81fdc2e63ed5d29e42c460a2 Mon Sep 17 00:00:00 2001 From: shitong Date: Mon, 17 Nov 2025 16:31:14 -0800 Subject: [PATCH 03/16] fix --- .github/workflows/unit-tests.yml | 4 ++-- Makefile | 14 +++++--------- pyproject.toml | 1 + 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 767a342..97ce60e 100644 --- 
a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies and run tests with multiple Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Tests +name: Python unit tests on: push: @@ -17,7 +17,7 @@ on: jobs: test: - name: Python ${{ matrix.python-version }} + name: Python unit tests (${{ matrix.python-version }}) runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/Makefile b/Makefile index 515b2b1..4996c2c 100644 --- a/Makefile +++ b/Makefile @@ -65,38 +65,34 @@ test: @export JAX_PLATFORMS=cpu && \ find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do \ echo "Running tests in $$test_file"; \ - uv run pytest "$$test_file" -v --tb=short || exit 1; \ + uv run --frozen pytest "$$test_file" -v --tb=short || exit 1; \ done @echo "✓ Unit tests completed!" test-fast: @echo "Running unit tests (parallel)..." - @JAX_PLATFORMS=cpu uv run pytest test/ -v --tb=short -n auto + @JAX_PLATFORMS=cpu uv run --frozen pytest test/ -v --tb=short -n auto test-all: @echo "Running unit tests..." @$(MAKE) test @echo "" @echo "Running distributed tests..." - @XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run pytest test_dist/ -n 0 + @XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run --frozen pytest test_dist/ -n 0 @echo "" @echo "Running tutorial tests..." @export JAX_PLATFORMS=cpu && \ for file in $$(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do \ if [ -f "$$file" ]; then \ echo "Testing $$file"; \ - python "$$file" || exit 1; \ + uv run --frozen python "$$file" || exit 1; \ fi \ done @echo "✓ All tests completed!" -test-gemma: - @echo "Running gemma tests..." - @JAX_PLATFORMS=cpu uv run pytest test/gemma/test_gemma.py -v - test-coverage: @echo "Running tests with coverage..." 
- @JAX_PLATFORMS=cpu uv run pytest test/ --cov=torchax --cov-report=html --cov-report=term + @JAX_PLATFORMS=cpu uv run --frozen pytest test/ --cov=torchax --cov-report=html --cov-report=term # === Cleaning === diff --git a/pyproject.toml b/pyproject.toml index ad68f4f..ed1a2e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ test = [ "expecttest>=0.3.0", "optax>=0.2.4", "termcolor>=2.0.0", + "flax>=0.10.6", # Required by checkpoint.py ] # Documentation dependencies From 008b786202b017d66d2b15b79a1a8ca5889deeed Mon Sep 17 00:00:00 2001 From: shitong Date: Mon, 17 Nov 2025 16:37:22 -0800 Subject: [PATCH 04/16] fix test --- .github/workflows/unit-tests.yml | 11 ++++++++--- Makefile | 10 +++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 97ce60e..cc5955f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -41,22 +41,27 @@ jobs: - name: Install dependencies run: | + source .venv/bin/activate uv pip install -e ".[cpu,test,docs]" - name: Run unit tests - run: make test + run: | + source .venv/bin/activate + make test continue-on-error: false - name: Run distributed tests run: | - XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run pytest test_dist/ -n 0 + source .venv/bin/activate + XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest test_dist/ -n 0 - name: Test tutorials can run run: | + source .venv/bin/activate export JAX_PLATFORMS=cpu for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do if [ -f "$FILE" ]; then echo "Testing $FILE" - uv run python "$FILE" + python "$FILE" fi done diff --git a/Makefile b/Makefile index 4996c2c..dd84e22 100644 --- a/Makefile +++ b/Makefile @@ -65,34 +65,34 @@ test: @export JAX_PLATFORMS=cpu && \ find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do \ echo "Running tests in $$test_file"; \ - uv run --frozen pytest "$$test_file" -v 
--tb=short || exit 1; \ + pytest "$$test_file" -v --tb=short || exit 1; \ done @echo "✓ Unit tests completed!" test-fast: @echo "Running unit tests (parallel)..." - @JAX_PLATFORMS=cpu uv run --frozen pytest test/ -v --tb=short -n auto + @JAX_PLATFORMS=cpu pytest test/ -v --tb=short -n auto test-all: @echo "Running unit tests..." @$(MAKE) test @echo "" @echo "Running distributed tests..." - @XLA_FLAGS=--xla_force_host_platform_device_count=4 uv run --frozen pytest test_dist/ -n 0 + @XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest test_dist/ -n 0 @echo "" @echo "Running tutorial tests..." @export JAX_PLATFORMS=cpu && \ for file in $$(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do \ if [ -f "$$file" ]; then \ echo "Testing $$file"; \ - uv run --frozen python "$$file" || exit 1; \ + python "$$file" || exit 1; \ fi \ done @echo "✓ All tests completed!" test-coverage: @echo "Running tests with coverage..." - @JAX_PLATFORMS=cpu uv run --frozen pytest test/ --cov=torchax --cov-report=html --cov-report=term + @JAX_PLATFORMS=cpu pytest test/ --cov=torchax --cov-report=html --cov-report=term # === Cleaning === From 8bb6ae8bd48f1964f8b448a8725bdaa5a35b6c69 Mon Sep 17 00:00:00 2001 From: shitong Date: Mon, 17 Nov 2025 16:50:53 -0800 Subject: [PATCH 05/16] pin torch and jax version --- pyproject.toml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed1a2e4..ab373ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,9 @@ build-backend = "hatchling.build" [project] name = "torchax" dependencies = [ - "jax[cpu]>=0.6.2", - "torch>=2.8.0", + "jax[cpu]==0.7.2", + "torch==2.8.0 ; sys_platform == 'darwin'", # macOS + "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Non-macOS (CPU-only) ] requires-python = ">=3.11" license = {file = "LICENSE"} @@ -58,10 +59,10 @@ path = "torchax/__init__.py" [project.optional-dependencies] # Hardware-specific backends -cpu = ["jax[cpu]>=0.6.2"] -tpu 
= ["jax[tpu]>=0.6.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html -cuda = ["jax[cuda12]>=0.6.2"] -odml = ["jax[cpu]>=0.6.2"] +cpu = ["jax[cpu]==0.7.2"] +tpu = ["jax[tpu]==0.7.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html +cuda = ["jax[cuda12]==0.7.2"] +odml = ["jax[cpu]==0.7.2"] # Development dependencies dev = [ From 855d6a66972fff032521505e2630992cb02974c2 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 10:30:13 -0800 Subject: [PATCH 06/16] seperate test dependency --- .github/workflows/unit-tests.yml | 2 +- CONTRIBUTING.md | 23 +++++++++++++---------- Makefile | 25 +++++++++---------------- dev-requirements.txt | 8 ++++---- pyproject.toml | 31 ++++++++++++++++++------------- test-requirements.txt | 4 ++-- 6 files changed, 47 insertions(+), 46 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index cc5955f..8af7286 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -42,7 +42,7 @@ jobs: - name: Install dependencies run: | source .venv/bin/activate - uv pip install -e ".[cpu,test,docs]" + make install-test - name: Run unit tests run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e143cd1..a777016 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,9 +26,9 @@ uv venv --python 3.11 source .venv/bin/activate # On Windows: .venv\Scripts\activate # Install with all development dependencies -make dev +make install # Or manually: -# uv pip install -e ".[cpu,dev,test,docs]" +# uv pip install -e ".[cpu,dev,docs]" # Run tests to verify setup make test @@ -53,25 +53,28 @@ Development works great on Apple Silicon Macs: # Option 1: Using uv (recommended) uv venv --python 3.11 source .venv/bin/activate -make dev +make install # Option 2: Using conda conda create --name torchax python=3.11 conda activate torchax -pip install -e ".[cpu,dev,test]" +pip install -e ".[cpu,dev,docs]" ``` ### Hardware-Specific Installation 
```bash -# CPU (default) -uv pip install -e ".[cpu,dev,test]" +# CPU (default - flexible versions for development) +make install + +# CPU with pinned test versions (exactly like CI) +make install-test # CUDA -uv pip install -e ".[cuda,dev,test]" +make install-cuda # TPU (requires additional setup) -uv pip install -e ".[tpu,dev,test]" +make install-tpu ``` ## Development Workflow @@ -185,13 +188,13 @@ Set the JAX platform explicitly: export JAX_PLATFORMS=cpu # or cuda, tpu ``` -### Tests failing locally but passing in CI +### If tests fail locally but pass in CI Make sure you have the latest dependencies: ```bash make clean -make dev +make install-test # Use exact CI versions ``` ## Documentation diff --git a/Makefile b/Makefile index dd84e22..b881c6c 100644 --- a/Makefile +++ b/Makefile @@ -3,17 +3,15 @@ # This Makefile provides convenient commands for development. # All commands use uv for fast, reliable dependency management. -.PHONY: help install dev test test-all lint format clean build docs +.PHONY: help install install-test test test-all lint format clean build docs # Default target help: @echo "TorchAx Development Commands" @echo "" - @echo "Setup & Installation:" - @echo " make install Install package with CPU backend" - @echo " make dev Install package with all dev dependencies" - @echo " make install-cuda Install with CUDA backend" - @echo " make install-tpu Install with TPU backend" + @echo "Setup & Installation (uses uv):" + @echo " make install Install for development (flexible versions)" + @echo " make install-test Install with pinned test versions (like CI)" @echo "" @echo "Development:" @echo " make lint Run linters (ruff check)" @@ -21,7 +19,6 @@ help: @echo " make test Run unit tests (file-by-file like CI)" @echo " make test-fast Run unit tests (parallel, faster)" @echo " make test-all Run all tests (unit + distributed + tutorials)" - @echo " make test-gemma Run gemma tests" @echo "" @echo "Cleaning:" @echo " make clean Clean build artifacts
and caches" @@ -34,16 +31,12 @@ help: # === Installation === install: - uv pip install -e ".[cpu]" + @echo "Installing with flexible versions for development..." + uv pip install -e ".[cpu,dev,docs]" -dev: - uv pip install -e ".[cpu,dev,test,docs]" - -install-cuda: - uv pip install -e ".[cuda,dev,test]" - -install-tpu: - uv pip install -e ".[tpu,dev,test]" +install-test: + @echo "Installing with pinned test versions (like CI)..." + pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu # === Linting & Formatting === diff --git a/dev-requirements.txt b/dev-requirements.txt index bcb026e..1b7609c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,14 +3,14 @@ **Note**: This file is kept for backward compatibility. New development should use: ```bash -make dev +make install # or -uv pip install -e ".[cpu,dev,test,docs]" +uv pip install -e ".[cpu,dev,docs]" ``` See `pyproject.toml` for the canonical list of dependencies organized by group: -- `[project.optional-dependencies.dev]` - Development tools -- `[project.optional-dependencies.test]` - Testing dependencies +- `[project.optional-dependencies.dev]` - Development tools (flexible versions) +- `[project.optional-dependencies.test]` - Testing dependencies (pinned versions) - `[project.optional-dependencies.docs]` - Documentation tools For hardware-specific dependencies, see: diff --git a/pyproject.toml b/pyproject.toml index ab373ca..a45686a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,8 @@ build-backend = "hatchling.build" [project] name = "torchax" dependencies = [ - "jax[cpu]==0.7.2", - "torch==2.8.0 ; sys_platform == 'darwin'", # macOS - "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Non-macOS (CPU-only) + "jax[cpu]>=0.6.2", # Flexible - works with newer versions + "torch>=2.4.0", # Flexible - works with newer versions ] requires-python = ">=3.11" license = {file = "LICENSE"} @@ -58,21 +57,28 @@ classifiers = [ path = "torchax/__init__.py" 
[project.optional-dependencies] -# Hardware-specific backends -cpu = ["jax[cpu]==0.7.2"] -tpu = ["jax[tpu]==0.7.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html -cuda = ["jax[cuda12]==0.7.2"] -odml = ["jax[cpu]==0.7.2"] +# Hardware-specific backends (flexible versions) +cpu = ["jax[cpu]>=0.6.2"] +tpu = ["jax[tpu]>=0.6.2"] # Add index: https://storage.googleapis.com/libtpu-wheels/index.html +cuda = ["jax[cuda12]>=0.6.2"] +odml = ["jax[cpu]>=0.6.2"] -# Development dependencies +# Development dependencies (flexible versions for latest features) dev = [ "ruff>=0.9.0", - "yapf==0.40.2", - "flax==0.10.6", + "yapf>=0.40.0", + "flax>=0.10.0", ] -# Test dependencies +# Test dependencies (PINNED versions for reproducible CI) test = [ + # Pinned PyTorch - platform specific + "torch==2.8.0 ; sys_platform == 'darwin'", # macOS from PyPI + "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index + # Other pinned versions + "jax[cpu]==0.7.2", + "flax==0.10.6", + # Test framework "pytest>=8.3.5", "pytest-xdist>=3.8.0", "pytest-reraise>=2.1.2", @@ -82,7 +88,6 @@ test = [ "expecttest>=0.3.0", "optax>=0.2.4", "termcolor>=2.0.0", - "flax>=0.10.6", # Required by checkpoint.py ] # Documentation dependencies diff --git a/test-requirements.txt b/test-requirements.txt index 85117f0..b85ca6b 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,9 +3,9 @@ **Note**: This file is kept for backward compatibility. New development should use: ```bash -make dev +make install-test # or -uv pip install -e ".[cpu,dev,test]" +pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu ``` See `pyproject.toml` `[project.optional-dependencies.test]` for the canonical list of test dependencies. 
From 2183224e3107ef77f55f1a958c1147f990e99cdd Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 10:33:03 -0800 Subject: [PATCH 07/16] lint logic --- Makefile | 6 +++--- pyproject.toml | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index b881c6c..7abeef2 100644 --- a/Makefile +++ b/Makefile @@ -42,13 +42,13 @@ install-test: lint: @echo "Running ruff check..." - @uv run ruff check torchax test test_dist + @ruff check torchax test test_dist @echo "✓ Linting passed!" format: @echo "Formatting code with ruff..." - @uv run ruff check torchax test test_dist --fix - @uv run ruff format torchax test test_dist + @ruff check torchax test test_dist --fix + @ruff format torchax test test_dist @echo "✓ Code formatted!" # === Testing === diff --git a/pyproject.toml b/pyproject.toml index a45686a..b85b1d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ train = [ ] # All dependencies for complete development setup +# Note: Install torch separately due to platform-specific builds all = [ "torchax[dev,test,docs,train]", ] From f8803a9e90090b4f4d5185926baaa6020f25a38a Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 11:01:42 -0800 Subject: [PATCH 08/16] install uv --- .github/workflows/lint.yml | 8 +- .github/workflows/nightly.yml | 7 +- .github/workflows/release.yml | 4 +- .github/workflows/static.yml | 10 ++- .gitignore | 1 + CONTRIBUTING.md | 44 +++++++--- Makefile | 156 ++++++++++++++++++++++------------ 7 files changed, 150 insertions(+), 80 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 04f3732..4b04557 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -28,11 +28,5 @@ jobs: - name: Set up Python run: uv python install 3.11 - - name: Create virtual environment - run: uv venv - - - name: Install dependencies - run: uv pip install ruff - - name: Run linter - run: make lint + run: make lint-check diff --git 
a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index ae77d78..da572cf 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -32,12 +32,15 @@ jobs: - name: Patch version for Nightly run: | - uv run python scripts/update_to_nightly_version.py + source .venv/bin/activate + uv run --no-project python scripts/update_to_nightly_version.py # Verify the change (Debugging) grep "__version__" torchax/__init__.py - name: Build package - run: make build + run: | + source .venv/bin/activate + make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 927ef1e..5111ec3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -42,7 +42,9 @@ jobs: run: uv venv - name: Build package - run: make build + run: | + source .venv/bin/activate + make build - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index 9bee94b..b64708b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -49,20 +49,24 @@ jobs: - name: Install dependencies run: | + source .venv/bin/activate uv pip install -e ".[cpu,docs]" - name: Generate notebooks working-directory: docs run: | + source ../.venv/bin/activate export JAX_PLATFORMS=cpu for FILE in $(find docs -name '*.py'); do export OUTFILE="${FILE%.*}.ipynb" - uv run jupytext --to ipynb $FILE -o "$OUTFILE" - uv run jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) + uv run --no-project jupytext --to ipynb $FILE -o "$OUTFILE" + uv run --no-project jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) done - name: Build documentation - run: make docs + run: | + source .venv/bin/activate + make docs - name: Upload artifact uses: actions/upload-pages-artifact@v3 diff --git a/.gitignore b/.gitignore index a6fc076..5fac29d 
100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,7 @@ activemq-data/ .env .envrc .venv +.local env/ venv/ ENV/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a777016..f450ec4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,32 +9,49 @@ If you plan to contribute new features, utility functions or extensions to the c ### Prerequisites - Python 3.11 or higher -- [uv](https://docs.astral.sh/uv/) (recommended) or pip +- Git -### Quick Start with uv (Recommended) +### Quick Start (Recommended) -```bash -# Install uv if you haven't already -curl -LsSf https://astral.sh/uv/install.sh | sh +The Makefile handles everything, including uv installation: +```bash # Clone the repository git clone https://github.com/google/torchax.git cd torchax -# Create a virtual environment (optional but recommended) +# Install uv locally (if not already in system) +make install-uv + +# Create a virtual environment uv venv --python 3.11 source .venv/bin/activate # On Windows: .venv\Scripts\activate # Install with all development dependencies make install -# Or manually: -# uv pip install -e ".[cpu,dev,docs]" -# Run tests to verify setup +# Verify setup +make check-env + +# Run tests make test ``` -### Alternative Setup (without uv) +### Alternative: System-wide uv + +If you prefer uv in your PATH: + +```bash +# Install uv system-wide +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Then follow the same steps above +uv venv --python 3.11 +source .venv/bin/activate +make install +``` + +### Without uv (Traditional) ```bash # Create a virtual environment @@ -42,7 +59,7 @@ python3.11 -m venv .venv source .venv/bin/activate # On Windows: .venv\Scripts\activate # Install the package in editable mode with dependencies -pip install -e ".[cpu,dev,test]" +pip install -e ".[cpu,dev,docs]" ``` ### Mac Setup (M1/M2/M3) @@ -50,12 +67,13 @@ pip install -e ".[cpu,dev,test]" Development works great on Apple Silicon Macs: ```bash -# Option 1: Using uv (recommended) +# Recommended: Using 
uv +make install-uv # Install uv if needed uv venv --python 3.11 source .venv/bin/activate make install -# Option 2: Using conda +# Alternative: Using conda conda create --name torchax python=3.11 conda activate torchax pip install -e ".[cpu,dev,docs]" diff --git a/Makefile b/Makefile index 7abeef2..e18d846 100644 --- a/Makefile +++ b/Makefile @@ -1,79 +1,128 @@ # TorchAx Development Makefile -# -# This Makefile provides convenient commands for development. -# All commands use uv for fast, reliable dependency management. -.PHONY: help install install-test test test-all lint format clean build docs +# Colors for output +ORANGE := \033[0;33m +GREEN := \033[0;32m +RESET := \033[0m + +# UV binary - use system uv if available, otherwise use local +UV := $(shell command -v uv 2>/dev/null || echo "./.local/bin/uv") + +# Prefer managed Python from uv +export UV_PYTHON_PREFERENCE := only-managed + +.PHONY: help install install-test lint lint-check format test test-fast test-all clean clean-all build docs # Default target help: - @echo "TorchAx Development Commands" + @echo "$(ORANGE)TorchAx Development Commands$(RESET)" @echo "" - @echo "Setup & Installation (uses uv):" - @echo " make install Install for development (flexible versions)" - @echo " make install-test Install with pinned test versions (like CI)" + @echo "Setup & Installation:" + @echo " make install-uv Install uv locally (if not in system)" + @echo " make install Install for development (flexible versions)" + @echo " make install-test Install with pinned test versions (like CI)" + @echo " make install-cuda Install with CUDA backend" + @echo " make install-tpu Install with TPU backend" @echo "" @echo "Development:" - @echo " make lint Run linters (ruff check)" - @echo " make format Format code with ruff" + @echo " make lint Auto-fix and format code" + @echo " make lint-check Check code without modifying" + @echo " make format Format code only" @echo " make test Run unit tests (file-by-file like CI)" - @echo " 
make test-fast Run unit tests (parallel, faster)" - @echo " make test-all Run all tests (unit + distributed + tutorials)" + @echo " make test-fast Run unit tests (parallel)" + @echo " make test-all Run all tests" @echo "" @echo "Cleaning:" - @echo " make clean Clean build artifacts and caches" - @echo " make clean-all Deep clean including Python caches" + @echo " make clean Clean build artifacts" + @echo " make clean-all Deep clean" @echo "" @echo "Building:" @echo " make build Build package" @echo " make docs Build documentation" +# === Ensure uv is available === + +.local/bin/uv: + @echo "$(ORANGE)Installing uv locally to .local/bin/...$(RESET)" + @mkdir -p .local/bin + @curl -LsSf https://astral.sh/uv/install.sh | sh + @if [ -f ~/.local/bin/uv ]; then \ + cp ~/.local/bin/uv .local/bin/uv; \ + echo "$(GREEN)✓ uv installed to .local/bin/uv$(RESET)"; \ + elif [ -f ~/.cargo/bin/uv ]; then \ + cp ~/.cargo/bin/uv .local/bin/uv; \ + echo "$(GREEN)✓ uv installed to .local/bin/uv$(RESET)"; \ + else \ + echo "Error: uv installation failed - check if it's in your PATH"; \ + exit 1; \ + fi + +install-uv: .local/bin/uv + @echo "$(GREEN)✓ uv is ready at .local/bin/uv$(RESET)" + # === Installation === install: - @echo "Installing with flexible versions for development..." - uv pip install -e ".[cpu,dev,docs]" + @echo "$(ORANGE)Installing for development (flexible versions)...$(RESET)" + @$(UV) pip install -e ".[cpu,dev,docs]" + @echo "$(GREEN)✓ Installation complete!$(RESET)" install-test: - @echo "Installing with pinned test versions (like CI)..." 
- pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu + @echo "$(ORANGE)Installing with pinned test versions (like CI)...$(RESET)" + @pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu + @echo "$(GREEN)✓ Installation complete!$(RESET)" + +install-cuda: + @$(UV) pip install -e ".[cuda,dev,test]" + +install-tpu: + @$(UV) pip install -e ".[tpu,dev,test]" # === Linting & Formatting === lint: - @echo "Running ruff check..." - @ruff check torchax test test_dist - @echo "✓ Linting passed!" + @echo "$(ORANGE)1. ==== Ruff format ====$(RESET)" + @$(UV) run --no-project ruff format torchax test test_dist + @echo "$(ORANGE)2. ==== Ruff check & fix ====$(RESET)" + @$(UV) run --no-project ruff check torchax test test_dist --fix + @echo "$(GREEN)✓ Code formatted and linted!$(RESET)" + +lint-check: + @echo "$(ORANGE)1. ==== Ruff format check ====$(RESET)" + @$(UV) run --no-project ruff format --check torchax test test_dist + @echo "$(ORANGE)2. ==== Ruff check ====$(RESET)" + @$(UV) run --no-project ruff check torchax test test_dist + @echo "$(GREEN)✓ Linting passed!$(RESET)" format: - @echo "Formatting code with ruff..." - @ruff check torchax test test_dist --fix - @ruff format torchax test test_dist - @echo "✓ Code formatted!" + @echo "$(ORANGE)Formatting code with ruff...$(RESET)" + @$(UV) run --no-project ruff format torchax test test_dist + @echo "$(GREEN)✓ Code formatted!$(RESET)" # === Testing === test: - @echo "Running unit tests..." + @echo "$(ORANGE)Running unit tests...$(RESET)" @export JAX_PLATFORMS=cpu && \ find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do \ - echo "Running tests in $$test_file"; \ - pytest "$$test_file" -v --tb=short || exit 1; \ + echo "Running tests for $$test_file"; \ + pytest "$$test_file" || exit 1; \ done - @echo "✓ Unit tests completed!" + @echo "$(GREEN)✓ Unit tests completed!$(RESET)" test-fast: - @echo "Running unit tests (parallel)..." 
+ @echo "$(ORANGE)Running unit tests (parallel)...$(RESET)" @JAX_PLATFORMS=cpu pytest test/ -v --tb=short -n auto + @echo "$(GREEN)✓ Tests completed!$(RESET)" test-all: - @echo "Running unit tests..." + @echo "$(ORANGE)Running all tests...$(RESET)" @$(MAKE) test @echo "" - @echo "Running distributed tests..." + @echo "$(ORANGE)Running distributed tests...$(RESET)" @XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest test_dist/ -n 0 @echo "" - @echo "Running tutorial tests..." + @echo "$(ORANGE)Running tutorial tests...$(RESET)" @export JAX_PLATFORMS=cpu && \ for file in $$(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do \ if [ -f "$$file" ]; then \ @@ -81,56 +130,55 @@ test-all: python "$$file" || exit 1; \ fi \ done - @echo "✓ All tests completed!" + @echo "$(GREEN)✓ All tests completed!$(RESET)" test-coverage: - @echo "Running tests with coverage..." + @echo "$(ORANGE)Running tests with coverage...$(RESET)" @JAX_PLATFORMS=cpu pytest test/ --cov=torchax --cov-report=html --cov-report=term # === Cleaning === clean: - @echo "Cleaning build artifacts..." + @echo "$(ORANGE)Cleaning build artifacts...$(RESET)" @rm -rf build/ dist/ *.egg-info @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true - @echo "✓ Clean complete!" + @echo "$(GREEN)✓ Clean complete!$(RESET)" clean-all: clean - @echo "Deep cleaning..." + @echo "$(ORANGE)Deep cleaning...$(RESET)" @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true @find . -type f -name "*.pyc" -delete @find . -type f -name "*.pyo" -delete @rm -rf .coverage htmlcov/ - @echo "✓ Deep clean complete!" + @echo "$(GREEN)✓ Deep clean complete!$(RESET)" # === Building === build: clean - @echo "Building package..." - @uv build - @echo "✓ Build complete!" + @echo "$(ORANGE)Building package...$(RESET)" + @$(UV) build + @echo "$(GREEN)✓ Build complete!$(RESET)" docs: - @echo "Building documentation..." 
- @cd docs && uv run mkdocs build - @echo "✓ Documentation built! Open docs/site/index.html" + @echo "$(ORANGE)Building documentation...$(RESET)" + @cd docs && $(UV) run mkdocs build + @echo "$(GREEN)✓ Documentation built!$(RESET)" docs-serve: - @echo "Serving documentation at http://127.0.0.1:8000" - @cd docs && uv run mkdocs serve + @echo "$(ORANGE)Serving documentation at http://127.0.0.1:8000$(RESET)" + @cd docs && $(UV) run mkdocs serve # === CI Simulation === -ci: lint test - @echo "✓ CI checks passed!" +ci: lint-check test + @echo "$(GREEN)✓ CI checks passed!$(RESET)" # === Utilities === check-env: @echo "Python: $$(python --version)" - @echo "uv: $$(uv --version)" - @python -c "import torch; print(f'PyTorch: {torch.__version__}')" - @python -c "import jax; print(f'JAX: {jax.__version__}')" - @python -c "import torchax; print(f'TorchAx: {torchax.__version__}')" - + @echo "uv: $$($(UV) --version 2>/dev/null || echo 'not available')" + @python -c "import torch; print(f'PyTorch: {torch.__version__}')" 2>/dev/null || echo "PyTorch: not installed" + @python -c "import jax; print(f'JAX: {jax.__version__}')" 2>/dev/null || echo "JAX: not installed" + @python -c "import torchax; print(f'TorchAx: {torchax.__version__}')" 2>/dev/null || echo "TorchAx: not installed" From db8682e47c3a35bcdafa11207f739c0b417878d6 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 11:06:58 -0800 Subject: [PATCH 09/16] fix --- .github/workflows/nightly.yml | 2 +- .github/workflows/static.yml | 4 ++-- Makefile | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index da572cf..db71e26 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -33,7 +33,7 @@ jobs: - name: Patch version for Nightly run: | source .venv/bin/activate - uv run --no-project python scripts/update_to_nightly_version.py + python scripts/update_to_nightly_version.py # Verify the change (Debugging) 
grep "__version__" torchax/__init__.py diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index b64708b..ccbd62b 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -59,8 +59,8 @@ jobs: export JAX_PLATFORMS=cpu for FILE in $(find docs -name '*.py'); do export OUTFILE="${FILE%.*}.ipynb" - uv run --no-project jupytext --to ipynb $FILE -o "$OUTFILE" - uv run --no-project jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) + uv tool run jupytext --to ipynb $FILE -o "$OUTFILE" + uv tool run jupyter nbconvert --to notebook --execute $OUTFILE --output $(basename $OUTFILE) done - name: Build documentation diff --git a/Makefile b/Makefile index e18d846..dea705c 100644 --- a/Makefile +++ b/Makefile @@ -82,21 +82,21 @@ install-tpu: lint: @echo "$(ORANGE)1. ==== Ruff format ====$(RESET)" - @$(UV) run --no-project ruff format torchax test test_dist + @$(UV) tool run ruff format torchax test test_dist @echo "$(ORANGE)2. ==== Ruff check & fix ====$(RESET)" - @$(UV) run --no-project ruff check torchax test test_dist --fix + @$(UV) tool run ruff check torchax test test_dist --fix @echo "$(GREEN)✓ Code formatted and linted!$(RESET)" lint-check: @echo "$(ORANGE)1. ==== Ruff format check ====$(RESET)" - @$(UV) run --no-project ruff format --check torchax test test_dist + @$(UV) tool run ruff format --check torchax test test_dist @echo "$(ORANGE)2. 
==== Ruff check ====$(RESET)" - @$(UV) run --no-project ruff check torchax test test_dist + @$(UV) tool run ruff check torchax test test_dist @echo "$(GREEN)✓ Linting passed!$(RESET)" format: @echo "$(ORANGE)Formatting code with ruff...$(RESET)" - @$(UV) run --no-project ruff format torchax test test_dist + @$(UV) tool run ruff format torchax test test_dist @echo "$(GREEN)✓ Code formatted!$(RESET)" # === Testing === From 2ea22af5a48e38813ea9d87eae3a70354780ef59 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 11:28:30 -0800 Subject: [PATCH 10/16] cancel running job --- .github/workflows/lint.yml | 5 +++++ .github/workflows/unit-tests.yml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4b04557..d836fec 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,6 +12,11 @@ on: - "docs/**" - "**.md" +# Cancel in-progress runs when a new commit is pushed +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: lint: name: Lint with Ruff diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 8af7286..5f71d24 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -15,6 +15,11 @@ on: - "docs/**" - "**.md" +# Cancel in-progress runs when a new commit is pushed +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: test: name: Python unit tests (${{ matrix.python-version }}) From ca1e022cd4f81f8558bd7cb57544f527d55f13cd Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 11:32:42 -0800 Subject: [PATCH 11/16] fix tutorials tests --- .github/workflows/unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 5f71d24..62e3b7e 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -67,6
+67,6 @@ jobs: for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do if [ -f "$FILE" ]; then echo "Testing $FILE" - python "$FILE" + .venv/bin/python "$FILE" fi done From 714be4a922bcc15a797a74cad1c74dca871410e8 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 11:54:56 -0800 Subject: [PATCH 12/16] debug --- .github/workflows/unit-tests.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 62e3b7e..6b97ca6 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -49,11 +49,11 @@ jobs: source .venv/bin/activate make install-test - - name: Run unit tests - run: | - source .venv/bin/activate - make test - continue-on-error: false + # - name: Run unit tests + # run: | + # source .venv/bin/activate + # make test + # continue-on-error: false - name: Run distributed tests run: | @@ -62,7 +62,11 @@ jobs: - name: Test tutorials can run run: | - source .venv/bin/activate + echo "=== Checking venv contents ===" + .venv/bin/pip list | grep -E "torch|jax|torchax" || echo "No packages found!" 
+ echo "=== Python path ===" + .venv/bin/python -c "import sys; print('\n'.join(sys.path))" + echo "=== Running tutorials ===" export JAX_PLATFORMS=cpu for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do if [ -f "$FILE" ]; then From 07dabd143bd84ca11f4486418f2204a0a7bccf6a Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 13:45:05 -0800 Subject: [PATCH 13/16] fix --- .github/workflows/unit-tests.yml | 8 ++------ Makefile | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 6b97ca6..a3f11dc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -62,15 +62,11 @@ jobs: - name: Test tutorials can run run: | - echo "=== Checking venv contents ===" - .venv/bin/pip list | grep -E "torch|jax|torchax" || echo "No packages found!" - echo "=== Python path ===" - .venv/bin/python -c "import sys; print('\n'.join(sys.path))" - echo "=== Running tutorials ===" + source .venv/bin/activate export JAX_PLATFORMS=cpu for FILE in $(find docs/docs/tutorials -name '*.py' 2>/dev/null || true); do if [ -f "$FILE" ]; then echo "Testing $FILE" - .venv/bin/python "$FILE" + python "$FILE" fi done diff --git a/Makefile b/Makefile index dea705c..d792290 100644 --- a/Makefile +++ b/Makefile @@ -69,7 +69,7 @@ install: install-test: @echo "$(ORANGE)Installing with pinned test versions (like CI)...$(RESET)" - @pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu + @$(UV) pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu @echo "$(GREEN)✓ Installation complete!$(RESET)" install-cuda: From bc127e583b76d3c7687323361ca802d108193853 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 14:04:24 -0800 Subject: [PATCH 14/16] fix --- .github/workflows/unit-tests.yml | 10 +++++----- pyproject.toml | 8 +++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git 
a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index a3f11dc..5f71d24 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -49,11 +49,11 @@ jobs: source .venv/bin/activate make install-test - # - name: Run unit tests - # run: | - # source .venv/bin/activate - # make test - # continue-on-error: false + - name: Run unit tests + run: | + source .venv/bin/activate + make test + continue-on-error: false - name: Run distributed tests run: | diff --git a/pyproject.toml b/pyproject.toml index b85b1d1..774b2ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,9 +72,11 @@ dev = [ # Test dependencies (PINNED versions for reproducible CI) test = [ - # Pinned PyTorch - platform specific - "torch==2.8.0 ; sys_platform == 'darwin'", # macOS from PyPI - "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index + # Pinned PyTorch - platform specific (installed together for compatibility) + "torch==2.8.0 ; sys_platform == 'darwin'", # macOS from PyPI + "torchvision==0.19.0 ; sys_platform == 'darwin'", # macOS from PyPI + "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index + "torchvision==0.19.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index # Other pinned versions "jax[cpu]==0.7.2", "flax==0.10.6", From 34f24f433a518075c424fc49138c474c3b87f103 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 14:14:16 -0800 Subject: [PATCH 15/16] torchvision --- Makefile | 2 +- pyproject.toml | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index d792290..f3400b6 100644 --- a/Makefile +++ b/Makefile @@ -69,7 +69,7 @@ install: install-test: @echo "$(ORANGE)Installing with pinned test versions (like CI)...$(RESET)" - @$(UV) pip install -e ".[test,docs]" --extra-index-url https://download.pytorch.org/whl/cpu + @$(UV) pip install -e ".[test,dev,docs]" --extra-index-url https://download.pytorch.org/whl/cpu @echo "$(GREEN)✓ Installation
complete!$(RESET)" install-cuda: diff --git a/pyproject.toml b/pyproject.toml index 774b2ed..8d59804 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,15 +68,14 @@ dev = [ "ruff>=0.9.0", "yapf>=0.40.0", "flax>=0.10.0", + "torchvision>=0.19.0", # For tutorials/examples ] # Test dependencies (PINNED versions for reproducible CI) test = [ - # Pinned PyTorch - platform specific (installed together for compatibility) + # Pinned PyTorch - platform specific "torch==2.8.0 ; sys_platform == 'darwin'", # macOS from PyPI - "torchvision==0.19.0 ; sys_platform == 'darwin'", # macOS from PyPI "torch==2.8.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index - "torchvision==0.19.0+cpu ; sys_platform != 'darwin'", # Linux from PyTorch index # Other pinned versions "jax[cpu]==0.7.2", "flax==0.10.6", From 1bcc194fc502060dbeceb653a3c62705018af079 Mon Sep 17 00:00:00 2001 From: shitong Date: Tue, 2 Dec 2025 15:10:32 -0800 Subject: [PATCH 16/16] add dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8d59804..a0cdfd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,12 +83,14 @@ test = [ "pytest>=8.3.5", "pytest-xdist>=3.8.0", "pytest-reraise>=2.1.2", + "setuptools>=70.0.0", # Required by torch.testing._internal "absl-py>=2.2.2", "immutabledict>=4.2.1", "sentencepiece>=0.2.0", "expecttest>=0.3.0", "optax>=0.2.4", "termcolor>=2.0.0", + "tensorboard>=2.15.0", ] # Documentation dependencies