joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)

I do not understand what this line of code means.
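Below is a small runnable sketch of what the line computes. It makes three assumptions. First, `labels` is a 1-D integer tensor of class ids in [0, P.n_classes). Second, `P` is some config object; here it is replaced by a hypothetical `SimpleNamespace` with only `n_classes`. Third, the `4` corresponds to four transformed copies of the same batch stacked along dim 0, for example the four 90-degree rotations used in some rotation-based self-supervised methods. The original snippet does not confirm that last point.

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for the config object `P`; only n_classes is used here.
P = SimpleNamespace(n_classes=3)

# Class labels for a batch of 2 samples, each in [0, P.n_classes).
labels = torch.tensor([0, 2])

# Assumption: the batch was replicated 4 times (e.g. rotated by 0/90/180/270
# degrees) and the copies were concatenated along dim 0 in that order.
# Copy i gets its labels shifted by P.n_classes * i, so every (class, copy)
# pair maps to a distinct id in [0, 4 * P.n_classes).
joint_labels = torch.cat([labels + P.n_classes * i for i in range(4)], dim=0)

print(joint_labels.tolist())  # [0, 2, 3, 5, 6, 8, 9, 11]

# The mapping is invertible: recover the original class and the copy index.
print((joint_labels % P.n_classes).tolist())   # [0, 2, 0, 2, 0, 2, 0, 2]
print((joint_labels // P.n_classes).tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Under these assumptions, each of the 4 copies gets its own block of `P.n_classes` label ids. A classifier trained on `joint_labels` would therefore need `4 * P.n_classes` outputs and would have to predict both the original class and which copy (transformation) a sample came from.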