Adds functionality to populate torch generator using torch.thread_safe_generator#9371
Adds functionality to populate torch generator using torch.thread_safe_generator#9371divyanshk wants to merge 4 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9371
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Pending, 2 Unrelated FailuresAs of commit 6277f11 with merge base 48956e0 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
1e0225e to
b15da1c
Compare
NicolasHug
left a comment
There was a problem hiding this comment.
Thanks for the PR @divyanshk . I think the changes look reasonable.
One thing I'm wondering is how does this affect the multiprocess-based dataloaders? Currently, since TV is using the global torch RNG, that global generator will be seeded by torch using a different seed for each process/worker. This is the correct behavior since we want each worker to have a different RNG.
Is that behavior preserved now that we're using torch.thread_safe_generator()?
It'd be good to have tests ensure that's the case (both for multiprocess and multithreaded cases).
b15da1c to
18bdef3
Compare
|
The multiprocessing case remains unchanged because torch.thread_safe_generator will return None for multiprocessing use-case. So for MP, there is no change. Earlier the torch.rand functions received None for generator arg, and now they would get the same. Also added a test case where I confirm the expected behavior for multiprocessing. |
c416fad to
e7da958
Compare
e7da958 to
6277f11
Compare
Added thread-safe random number generation to all V2 torchvision random transforms to prevent race conditions when using DataLoader with thread-based workers (worker_method='thread').
This is based on
torch.thread_safe_generatorwhich returns dataloader thread-worker specific RNG or None otherwise.