
Fix Blackwell (sm_120) illegal memory access via CUDA stream handling #29

Open
cuzelac wants to merge 1 commit into JeffreyXiang:main from cuzelac:fix/blackwell-stream-sync-upstream

Conversation

@cuzelac cuzelac commented Mar 8, 2026

Summary

CuMesh causes CUDA error: illegal memory access on NVIDIA Blackwell GPUs (RTX 5090, sm_120) when used with PyTorch. Setting CUDA_LAUNCH_BLOCKING=1 masks the issue (by serializing all GPU work), confirming a stream synchronization problem.

Root cause

All CUDA kernel launches, CUB operations, and cudaMemcpy calls in CuMesh use the default CUDA stream (stream 0). PyTorch allocates memory on cudaStreamNonBlocking streams, which have no implicit synchronization with stream 0. This means CuMesh kernels can execute before PyTorch's memory operations complete, or PyTorch can free/reuse memory while CuMesh kernels are still running on stream 0.

This has always been undefined behavior per the CUDA programming model, but only manifests reliably on Blackwell GPUs — likely due to changes in the memory subsystem or stream scheduling.
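The hazard can be illustrated with a minimal sketch (the names `my_kernel`, `input`, and `unsafe` are hypothetical, not taken from the actual CuMesh code):

```cuda
#include <torch/extension.h>

__global__ void my_kernel(const float* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * 2.0f;
}

void unsafe(torch::Tensor input, torch::Tensor out) {
    int n = input.numel();
    // Launches on the default stream (stream 0). PyTorch's allocator and
    // ops are ordered on its own non-blocking stream, so nothing
    // guarantees `input` is fully written before this kernel reads it,
    // or that PyTorch won't recycle the memory while the kernel runs.
    my_kernel<<<(n + 255) / 256, 256>>>(
        input.data_ptr<float>(), out.data_ptr<float>(), n);
}
```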

Fix

  • Retrieve PyTorch's current CUDA stream via at::cuda::getCurrentCUDAStream() in every function
  • Pass the stream to all kernel launches (<<<blocks, threads, 0, stream>>>)
  • Pass the stream to all CUB calls (e.g., cub::DeviceScan::ExclusiveSum(..., stream))
  • Use cudaMemcpyAsync with the stream instead of synchronous cudaMemcpy
  • Add cudaStreamSynchronize(stream) before all cudaFree calls, since cudaFree on non-blocking streams may free memory before async kernels finish using it
  • Initialize scratch pointers to nullptr and conditionally free them at the end of functions (after sync)

This follows the same pattern used by nvdiffrast and cubvh.
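Put together, the fixed pattern for a typical function looks roughly like this (a sketch built around a hypothetical CUB exclusive-sum step, not the literal PR diff):

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

void example(int* d_counts, int* d_offsets, int n) {
    // Use PyTorch's current stream instead of the default stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    void* d_temp = nullptr;   // scratch initialized to nullptr
    size_t temp_bytes = 0;
    // First call sizes the scratch buffer; second call runs the scan.
    // Both take the stream as the trailing argument.
    cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes,
                                  d_counts, d_offsets, n, stream);
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes,
                                  d_counts, d_offsets, n, stream);

    // Async copy on the same stream, then sync before the CPU reads it.
    int last_offset = 0;
    cudaMemcpyAsync(&last_offset, d_offsets + n - 1, sizeof(int),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    // Sync before freeing: cudaFree must not reclaim memory still in use
    // by async work queued on a non-blocking stream. (Here the preceding
    // synchronize already covers it; the conditional free matches the
    // nullptr-initialized scratch pointer.)
    if (d_temp) cudaFree(d_temp);
}
```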

Note: An equivalent PR has been submitted to visualbruno/CuMesh which includes additional functions not present in this repo.

Testing

  • Tested on RTX 5090 (sm_120), PyTorch 2.10.0+cu130, CUDA 13.0, Windows
  • Full Trellis2 image-to-3D pipeline including mesh generation, refinement, and texturing
  • Multiple successful runs with zero crashes, without setting CUDA_LAUNCH_BLOCKING=1
  • Verified the fix is necessary by reverting and reproducing the crash

Files changed

  • src/remesh/svox2vert.cu — sparse voxel grid vertex extraction
  • src/connectivity.cu — mesh connectivity computation
  • src/clean_up.cu — mesh cleanup operations
  • src/atlas.cu — texture atlas generation
  • src/shared.h — template compress_ids used throughout
  • src/utils.h — Buffer::free() stream safety
  • src/geometry.cu, src/hash/hash.cu, src/remesh/simple_dual_contour.cu, src/simplify.cu — stream passing for kernel launches and CUB calls

Fix Blackwell (sm_120) illegal memory access via CUDA stream handling

PyTorch uses cudaStreamNonBlocking streams, but CuMesh launched all
kernels on the default stream (stream 0) via bare <<<blocks, threads>>>
syntax. On Blackwell GPUs (RTX 5090, CUDA 13.0), this stream mismatch
causes "illegal memory access" errors.

Changes:
- Add current_stream() helper wrapping at::cuda::getCurrentCUDAStream()
- Pass PyTorch's current CUDA stream to all kernel launches (<<<..., 0, stream>>>),
  CUB calls, cudaMemcpyAsync, and cudaMemsetAsync across all 10 source files
- Fix cudaFree race conditions: on non-blocking streams, cudaFree may
  return memory to the pool before async kernels finish using it. Added
  cudaStreamSynchronize(stream) before cudaFree where the freed memory
  was recently used by async work on the stream
- Convert cudaMemcpy to cudaMemcpyAsync with stream, adding
  cudaStreamSynchronize where the CPU needs to read the result
- Replace timing-related cudaDeviceSynchronize calls in simplify.cu
  with cudaStreamSynchronize(stream)
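The current_stream() helper mentioned above is presumably a thin wrapper over the standard ATen API (a sketch; the PR's actual helper may differ):

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Returns PyTorch's current CUDA stream for the active device as a raw
// cudaStream_t, ready to pass to <<<...>>> launches, CUB calls,
// cudaMemcpyAsync, and cudaMemsetAsync.
static inline cudaStream_t current_stream() {
    return at::cuda::getCurrentCUDAStream().stream();
}
```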

Tested on RTX 5090 32GB, PyTorch 2.10.0+cu130, CUDA 13.0, Windows.
Full Trellis2 reconstruction pipeline completes without errors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

cuzelac commented Mar 8, 2026

Also filed as visualbruno/CuMesh#2 for the visualbruno fork, which is used by ComfyUI-Trellis2.


cuzelac commented Mar 8, 2026

This may also fix #27 (fragmented mesh output on Blackwell) — same root cause (stream 0 race condition), different symptom.
