Skip to content
This repository was archived by the owner on Mar 3, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchprime/launcher/save_hf_assets_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def _upload_directory_to_gcs(local_path: Path, gcs_path: str):
"""Uploads the contents of a local directory to GCS using gsutil.
"""Uploads the contents of a local directory to GCS using gcloud storage.

Args:
local_path: The local directory whose contents will be uploaded.
Expand All @@ -42,7 +42,7 @@ def _upload_directory_to_gcs(local_path: Path, gcs_path: str):
raise ValueError("GCS path must start with gs://")

logger.info(f"Uploading contents of '{local_path}' to '{gcs_path}'...")
command = ["gsutil", "-m", "cp", "-r", f"{str(local_path).rstrip('/')}/*", gcs_path]
command = ["gcloud", "storage", "cp", "--recursive", f"{str(local_path).rstrip('/')}/*", gcs_path]
try:
subprocess.run(command, check=True, capture_output=True, text=True)
logger.info(f"Successfully uploaded assets to {gcs_path}.")
Expand Down
39 changes: 19 additions & 20 deletions torchprime/torch_xla_models/model/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,24 +354,24 @@ def convert_to_safetensors_on_cpu(model: torch.nn.Module, save_dir: Path) -> Non
def maybe_move_to_mounted_gcs(tmp_dir: Path | None, save_dir: Path):
  """Move checkpoint files from a local temp dir to the final save dir.

  If ``tmp_dir`` is provided, moves ``*.safetensors`` files and the index
  file from ``tmp_dir`` to ``save_dir`` using ``gcloud storage`` when the
  ``gcloud`` CLI is available, falling back to a shutil-based copy when it
  is missing or its upload fails. If ``tmp_dir`` is None, the checkpoint is
  assumed to already be in ``save_dir`` and nothing is done.

  Args:
    tmp_dir: Local directory containing files to upload, or None.
    save_dir: Destination directory (e.g., /tmp/gcs-mount/...).
      NOTE: must be a Path — the body calls ``save_dir.mkdir``; the previous
      ``str`` annotation contradicted both the docstring and the code.
  """
  if not tmp_dir:
    # Nothing staged locally; the caller wrote directly to save_dir.
    logger.warning("No tmp_dir provided, checkpoint already saved in save_dir.")
    return

  save_dir.mkdir(parents=True, exist_ok=True)
  if shutil.which("gcloud"):
    try:
      # gcloud storage seems to give 8x speedup over shutil.copy2
      # (originally benchmarked with gsutil — TODO confirm for gcloud).
      logger.info("Using gcloud storage for upload")
      move_to_mounted_gcs_gsutil(tmp_dir, save_dir)
      return
    except subprocess.CalledProcessError as e:
      logger.warning("gcloud storage failed: %s. Falling back to shutil-based copy.", str(e))
  else:
    logger.info("gcloud not found. Falling back to shutil-based copy.")
  # Fallback path: reached when gcloud is absent or its upload failed.
  move_to_mounted_gcs_shutil(tmp_dir, save_dir)
Expand All @@ -380,26 +380,25 @@ def maybe_move_to_mounted_gcs(tmp_dir: Path | None, save_dir: str):
def move_to_mounted_gcs_gsutil(work_dir: Path, save_dir: str):
"""
Moves *.safetensors files and index file from work_dir to save_dir,
using gsutil for efficient upload to a mounted GCS bucket.
using gcloud storage for efficient upload to a mounted GCS bucket.

Args:
work_dir (Path): Local directory containing files to upload.
save_dir (Path): Destination directory (e.g., /tmp/gcs-mount/...)
"""
save_dir.mkdir(parents=True, exist_ok=True)
cmd = [
"gsutil",
"-m", # Enables parallel (multi-threaded) execution for faster copying
"-q", # Suppresses all output unless errors occur (quiet mode)
"cp", # Copy command
"-n", # No-clobber: skip files that already exist at the destination
"gcloud",
"storage",
"cp",
"--no-clobber", # No-clobber: skip files that already exist at the destination
*(str(p) for p in work_dir.glob("*.safetensors")), # All .safetensors files to copy
str(save_dir) + "/", # Destination directory in the GCS bucket
]
cmd_idx = [
"gsutil",
"-q", # Quiet mode
"cp", # Copy command
"gcloud",
"storage",
"cp",
str(work_dir / "model.safetensors.index.json"), # Source index file
str(save_dir) + "/", # Destination directory
]
Expand Down Expand Up @@ -492,7 +491,7 @@ def local_path_from_gcs(path_or_repo: str, temp_dir: str | None = None):

If the input `path_or_repo` starts with 'gs://', this function will download
the contents of the GCS directory to a temporary local directory using the
`gsutil` command-line tool. The local directory will be automatically cleaned
`gcloud storage` command-line tool. The local directory will be automatically cleaned
up when the context is exited.

If the input is not a GCS path, it is assumed to be a local path or a
Expand All @@ -512,16 +511,16 @@ def local_path_from_gcs(path_or_repo: str, temp_dir: str | None = None):
yield path_or_repo
return

if not shutil.which("gsutil"):
if not shutil.which("gcloud"):
raise RuntimeError(
"gsutil command not found, but is required for downloading from GCS. "
"gcloud command not found, but is required for downloading from GCS. "
"Please install the Google Cloud SDK."
)

local_dir = tempfile.mkdtemp(dir=temp_dir)
try:
gcs_path = path_or_repo.rstrip("/") + "/*"
command = ["gsutil", "-m", "-q", "cp", "-r", gcs_path, local_dir]
command = ["gcloud", "storage", "cp", "--recursive", gcs_path, local_dir]
subprocess.run(command, check=True, capture_output=True, text=True)
logger.info(
"Successfully downloaded files from %s to temporary directory %s.",
Expand All @@ -531,7 +530,7 @@ def local_path_from_gcs(path_or_repo: str, temp_dir: str | None = None):

yield local_dir
except subprocess.CalledProcessError as e:
logger.error("gsutil download failed for %s. Stderr:\n%s", path_or_repo, e.stderr)
logger.error("gcloud storage download failed for %s. Stderr:\n%s", path_or_repo, e.stderr)
raise
finally:
logger.info(f"Cleaning up temporary directory: {local_dir}")
Expand Down