diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py index e37a3ea..bfea5d8 100644 --- a/trellis2/models/sc_vaes/sparse_unet_vae.py +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -681,7 +681,9 @@ def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor] h = block(h) # More frequent cache clearing in low_vram mode to reclaim memory between blocks if self.low_vram: + torch.cuda.synchronize() torch.cuda.empty_cache() + torch.cuda.synchronize() torch.cuda.empty_cache() if self.low_vram: def fused_finalize(t): diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py index 3f69059..26542f2 100644 --- a/trellis2/pipelines/trellis2_image_to_3d.py +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -1166,11 +1166,13 @@ def run( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() pbar.update(1) if return_latent: @@ -1518,11 +1520,13 @@ def run_cascade( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() pbar.update(1) @@ -1728,11 +1732,13 @@ def run_multiview( if pbar is not None: pbar.update(1) + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() if pbar is not None: @@ -2479,8 +2485,10 @@ def texture_mesh( if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() pbr_voxel = self.decode_tex_slat(tex_slat) + torch.cuda.synchronize() torch.cuda.empty_cache() out_mesh, baseColorTexture, metallicRoughnessTexture = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size, texture_alpha_mode, double_side_material, bake_on_vertices, use_custom_normals, mesh_cluster_threshold_cone_half_angle_rad) @@ -2574,8 +2582,10 @@ def texture_mesh_multiview( if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() pbr_voxel = self.decode_tex_slat(tex_slat) + torch.cuda.synchronize() torch.cuda.empty_cache() out_mesh, baseColorTexture, metallicRoughnessTexture = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size, texture_alpha_mode, double_side_material, bake_on_vertices, use_custom_normals, mesh_cluster_threshold_cone_half_angle_rad) @@ -2812,11 +2822,13 @@ def refine_mesh( if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() + torch.cuda.synchronize() torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) + torch.cuda.synchronize() torch.cuda.empty_cache() if return_latent: