I found that even though this is a single-step diffusion model, memory consumption during training is still quite large. Since I only have an RTX 3090, I'm recording the settings I used in the hope that they help others in the same situation.
1. Reduce --train_batch_size from 4 to 2.
2. Optionally, also set --gradient_checkpointing to reduce memory use further.
With these settings, the results can be reproduced!
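
For anyone who wants to script this, here is a minimal sketch of how the two settings above might be assembled into a launch command. The script name `train.py` and the helper `build_train_args` are placeholders I made up for illustration; only `--train_batch_size` and `--gradient_checkpointing` come from my actual run, so substitute the repo's real training script and keep whatever other arguments you normally pass.

```python
import shlex


def build_train_args(script="train.py", batch_size=2, gradient_checkpointing=True):
    """Assemble the command line for a low-memory training run.

    `script` is a hypothetical placeholder, not the repo's real entry point.
    """
    args = ["python", script, "--train_batch_size", str(batch_size)]
    if gradient_checkpointing:
        # Boolean flag: it is either present or absent, with no value.
        args.append("--gradient_checkpointing")
    return args


if __name__ == "__main__":
    # Print the command so it can be inspected or pasted into a shell.
    print(shlex.join(build_train_args()))
```

Passing `gradient_checkpointing=False` drops the flag, which makes it easy to compare memory use with and without it.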