diff --git a/src/madengine/tools/run_models.py b/src/madengine/tools/run_models.py
index a620d96f..f77a14de 100644
--- a/src/madengine/tools/run_models.py
+++ b/src/madengine/tools/run_models.py
@@ -631,6 +631,13 @@ def run_model_impl(
             print(f"MAD_CONTAINER_IMAGE is {run_details.docker_image}")
             print(f"Warning: User override MAD_CONTAINER_IMAGE. Model support on image not guaranteed.")

+        persistent_hf_cache_dir = self.context.ctx.get("persistent_hf_cache_dir", None)
+        hf_models_cache_volume_name = "hf_models_cache_volume"
+
+        # Create Docker volume for models caching
+        if persistent_hf_cache_dir:
+            self.console.sh(f"docker volume create {hf_models_cache_volume_name}")
+
         # prepare docker run options
         gpu_vendor = self.context.ctx["gpu_vendor"]
         docker_options = ""
@@ -695,13 +702,17 @@ def run_model_impl(
         # Must set env vars and mounts at the end
         docker_options += self.get_env_arg(run_env)
         docker_options += self.get_mount_arg(mount_datapaths)
+
+        if persistent_hf_cache_dir:
+            docker_options += f" -v {hf_models_cache_volume_name}:{persistent_hf_cache_dir} "
+
         docker_options += f" {run_details.additional_docker_run_options}"

         # if --shm-size is set, remove --ipc=host
         if "SHM_SIZE" in self.context.ctx:
             docker_options = docker_options.replace("--ipc=host", "")

-        print(docker_options)
+        print(f"Docker options: {docker_options}")

         # get machine name
         run_details.machine_name = self.console.sh("hostname")
@@ -746,6 +757,10 @@ def run_model_impl(

         model_docker.sh("rm -rf " + model_dir, timeout=240)

+        if persistent_hf_cache_dir:
+            print("HF models cache directory content:")
+            model_docker.sh(f"ls -la {persistent_hf_cache_dir}")
+
         # set safe.directory for workspace
         model_docker.sh("git config --global --add safe.directory /myworkspace")

@@ -841,34 +856,19 @@ def run_model_impl(

         # run model
         test_start_time = time.time()
+        model_args = self.context.ctx.get("model_args", info["args"])
+
+        if persistent_hf_cache_dir:
+            model_args += f" --hf_cache_dir {persistent_hf_cache_dir} "
+
+        run_command = f"cd {model_dir} && {script_name} {model_args}"
+
         if not self.args.skip_model_run:
             print("Running model...")
-            if "model_args" in self.context.ctx:
-                model_docker.sh(
-                    "cd "
-                    + model_dir
-                    + " && "
-                    + script_name
-                    + " "
-                    + self.context.ctx["model_args"],
-                    timeout=None,
-                )
-            else:
-                model_docker.sh(
-                    "cd " + model_dir + " && " + script_name + " " + info["args"],
-                    timeout=None,
-                )
+            model_docker.sh(run_command, timeout=None)
         else:
             print("Skipping model run")
-            print(
-                "To run model: "
-                + "cd "
-                + model_dir
-                + " && "
-                + script_name
-                + " "
-                + info["args"]
-            )
+            print(f"To run model: {run_command}")

         run_details.test_duration = time.time() - test_start_time
         print("Test Duration: {} seconds".format(run_details.test_duration))
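
For reference, below is a minimal standalone sketch (not part of the diff) of the caching flow this change introduces: a named Docker volume is created once on the host and mounted at `persistent_hf_cache_dir` inside the container, so Hugging Face model downloads survive across runs. The volume name, context key, and `--hf_cache_dir` flag mirror the diff; the image and cache path below are assumptions for illustration only.

```python
"""Sketch of the persistent HF cache flow, assuming ubuntu:22.04 as a stand-in
image and /workspace/hf_cache as the value of the persistent_hf_cache_dir key."""
import subprocess

HF_VOLUME = "hf_models_cache_volume"             # volume name used in the diff
persistent_hf_cache_dir = "/workspace/hf_cache"  # hypothetical context value

# Idempotent: `docker volume create` reuses the volume if it already exists.
subprocess.run(["docker", "volume", "create", HF_VOLUME], check=True)

# Mount the named volume at the cache path; files written there outlive the container,
# so a subsequent run that mounts the same volume sees the previously cached models.
subprocess.run(
    [
        "docker", "run", "--rm",
        "-v", f"{HF_VOLUME}:{persistent_hf_cache_dir}",
        "ubuntu:22.04",
        "bash", "-c",
        f"touch {persistent_hf_cache_dir}/cached_model.bin && ls -la {persistent_hf_cache_dir}",
    ],
    check=True,
)
```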