From 413cb597c99fb0134f269a742ba55e546e2f730e Mon Sep 17 00:00:00 2001 From: Chris Uzelac Date: Sat, 7 Mar 2026 23:01:04 -0800 Subject: [PATCH] Add torch.cuda.synchronize() before every empty_cache() call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch.cuda.empty_cache() does not synchronize — it can free memory that in-flight async kernels are still using, causing "illegal memory access" errors on high-concurrency GPUs (e.g. RTX 5090 / Blackwell). Add synchronize() before each empty_cache() to ensure all GPU work completes before memory is released. Co-Authored-By: Claude Opus 4.6 --- trellis2/models/sc_vaes/sparse_unet_vae.py | 2 ++ trellis2/pipelines/trellis2_image_to_3d.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py index e37a3ea..bfea5d8 100644 --- a/trellis2/models/sc_vaes/sparse_unet_vae.py +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -681,7 +681,9 @@ def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor] h = block(h) # More frequent cache clearing in low_vram mode to reclaim memory between blocks if self.low_vram: + torch.cuda.synchronize() torch.cuda.empty_cache() + torch.cuda.synchronize() torch.cuda.empty_cache() if self.low_vram: def fused_finalize(t): diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py index 3f69059..26542f2 100644 --- a/trellis2/pipelines/trellis2_image_to_3d.py +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -1166,11 +1166,13 @@ def run( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() pbar.update(1) if return_latent: @@ -1518,11 +1520,13 @@ def run_cascade( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() pbar.update(1) @@ -1728,11 +1732,13 @@ def run_multiview( if pbar is not None: pbar.update(1) + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() if pbar is not None: @@ -2479,8 +2485,10 @@ def texture_mesh( if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() pbr_voxel = self.decode_tex_slat(tex_slat) + torch.cuda.synchronize() torch.cuda.empty_cache() out_mesh, baseColorTexture, metallicRoughnessTexture = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size, texture_alpha_mode, double_side_material, bake_on_vertices, use_custom_normals, mesh_cluster_threshold_cone_half_angle_rad) @@ -2574,8 +2582,10 @@ def texture_mesh_multiview( if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() pbr_voxel = self.decode_tex_slat(tex_slat) + torch.cuda.synchronize() torch.cuda.empty_cache() out_mesh, baseColorTexture, metallicRoughnessTexture = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size, texture_alpha_mode, double_side_material, bake_on_vertices, use_custom_normals, mesh_cluster_threshold_cone_half_angle_rad) @@ -2812,11 +2822,13 @@ def refine_mesh( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() if return_latent: