Hi, @CircleRadon. Thank you for your great work. I am unclear about how the terms of the APro loss are weighted and how the loss is implemented.
According to issue 3, the APro loss is implemented as follows:
import torch
import torch.nn as nn

# Global_APro, Local_APro and MinimumSpanningTree are assumed to be
# provided by the APro repository; their import path is not shown here.


class AproLoss(nn.Module):
    def __init__(self, ignore_index=255):
        super().__init__()
        # partial cross entropy
        self.partialCE = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # apro
        self.global_apro = Global_APro()
        self.local_apro = Local_APro(kernel_size=5, zeta_s=0.15)  # set kernel_size and zeta_s
        self.mst = MinimumSpanningTree(Global_APro.norm2_distance)
        # pca n_component
        # self.q = 1
        self.ignore_index = ignore_index

    def forward(self, x, y_hat, y):
        # x: B, C, H, W
        # y_hat: B, classes, H, W
        # y: B, H, W
        # partial cross entropy on the labelled pixels only
        partial = self.partialCE(y_hat, y)
        # compute PCA
        # B, 1, H, W
        # pca_imgs = self.compute_pca(x)
        # compute image tree
        # I think directly using x is also fine
        img_mst_tree = self.mst(x)
        # img_mst_tree = self.mst(pca_imgs)
        # y = y.float()
        y_hat = torch.softmax(y_hat, dim=1)  # convert to probability [0, 1]
        # pseudo label for global info
        # using low level feature
        soft_pseudo = self.global_apro(y_hat, x, img_mst_tree, zeta_g=0.001)
        # using deep feature
        soft_pseudo = self.global_apro(soft_pseudo, y_hat, img_mst_tree, zeta_g=0.05)
        # unlabelled region only (mask of shape B, 1, H, W, broadcast over classes)
        unlabelled_regions = (y.unsqueeze(1) == self.ignore_index)
        # compute difference between generated pseudo labels and predicted ones
        loss_global_term = torch.abs(soft_pseudo - y_hat) * unlabelled_regions
        # normalize the loss by the number of unlabelled pixels
        n_regions = unlabelled_regions.sum().clamp(min=1)
        loss_global = loss_global_term.sum() / n_regions
        # local term
        # The original snippet passed pca_imgs here, which is undefined because
        # the PCA step above is commented out; x is used instead.
        soft_pseudo = self.local_apro(x, y_hat)
        loss_local_term = torch.abs(y_hat - soft_pseudo) * unlabelled_regions
        loss_local = loss_local_term.sum() / n_regions
        # the three terms are summed without explicit weights
        return partial + loss_global + loss_local
I have several questions:
- How should the weights of the partial cross entropy and the global/local APro losses be set? In the code above the three terms are simply summed.
- For global APro, the deep feature is directly set to y_hat. Is this the default setting in your paper? Why should it be set to the last feature map from the segmentation network?
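To make the first question concrete, here is a minimal sketch of the kind of weighting I have in mind. The function `combine_losses` and the coefficients `lambda_global` and `lambda_local` are hypothetical names of my own, not taken from the paper or the repository. With both set to 1.0 it reduces to the unweighted sum returned by `forward` above.

```python
def combine_losses(partial, loss_global, loss_local,
                   lambda_global=1.0, lambda_local=1.0):
    """Weighted sum of the three loss terms.

    lambda_global and lambda_local are hypothetical weights; the defaults
    of 1.0 reproduce the plain sum `partial + loss_global + loss_local`.
    Works on Python floats or on scalar tensors.
    """
    return partial + lambda_global * loss_global + lambda_local * loss_local


# Example with plain floats standing in for the scalar loss values.
total_default = combine_losses(1.0, 2.0, 3.0)
total_weighted = combine_losses(1.0, 2.0, 3.0,
                                lambda_global=0.5, lambda_local=0.25)
```

I would like to know which values of these weights, if any, were used for the results reported in the paper.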
Thank you in advance.