Add DirectCLR (#781) #1874
The implementation looks correct. However, I am not sure this is really the right format for including this model in LightlySSL. I think I would prefer something slightly simpler, without a "model file", i.e. just the loss (using the already existing NTXentLoss):
# directclr_loss.py
from torch import Tensor
from torch.nn import Module

from lightly.loss import NTXentLoss


class DirectCLRLoss(Module):
    def __init__(self, loss_dim: int, temperature, ...):  # all the other NTXentLoss params
        super().__init__()
        self.loss_dim = loss_dim
        self.ntxent_loss = NTXentLoss(...)

    def forward(self, x0: Tensor, x1: Tensor) -> Tensor:  # both (B, D)
        # Apply NT-Xent only to the first loss_dim embedding dimensions.
        return self.ntxent_loss(
            x0[..., : self.loss_dim], x1[..., : self.loss_dim]
        )

and then the example would be basically exactly like https://github.com/lightly-ai/lightly/blob/master/examples/pytorch/simclr.py, just with the new loss function and without the projection head.
If that's too much work I can adjust it accordingly myself. :)
liopeer
left a comment
Awesome work! It's looking really nice and many of my comments are really just nitpicking and beautification.
The two things we should really change however before merging: The *args, **kwargs story, and the slightly too comprehensive 🙃 unit tests. After that we're ready to merge.
Quick note: the […] We could then have a follow-up issue implementing the whole DirectCLR method at […]
yutong-xiang-97
left a comment
Added my review as well.
24f1492 to c4a7e9e
I have:
Is there a standard way to ensure or even automate that the parameters relating to NTXentLoss always match? This way the defaults can get out of sync pretty easily :(
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff             @@
##           master    #1874      +/-   ##
==========================================
+ Coverage   86.07%   86.10%   +0.02%
==========================================
  Files         167      168       +1
  Lines        6966     6979      +13
==========================================
+ Hits         5996     6009      +13
  Misses        970      970
yutong-xiang-97
left a comment
Nice improvements! I left some comments on tiny details to beautify the tests; after that you're good to go.
Is there a standard way to ensure or even automate that the parameters relating to NTXentLoss always match? This way the defaults can get out of sync pretty easily :(
In this specific case, I don't expect the interface of the NTXent loss or the DirectCLR loss to change. They're well-defined losses, with defaults from the original papers that everyone respects.
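For cases where drift is a concern, one possible approach is a small unit test that compares the defaults of the parameters both constructors share. The sketch below is only an illustration using the standard library's `inspect.signature`; `BaseLoss` and `SlicedLoss` are hypothetical stand-ins (not the Lightly classes), and their parameter names and default values are assumptions chosen for the example.

```python
import inspect


def shared_defaults_mismatch(wrapper, wrapped):
    """Return {name: (wrapper_default, wrapped_default)} for every
    parameter both callables accept but give different defaults."""
    wrapper_params = inspect.signature(wrapper).parameters
    wrapped_params = inspect.signature(wrapped).parameters
    mismatches = {}
    for name in wrapper_params.keys() & wrapped_params.keys():
        a = wrapper_params[name].default
        b = wrapped_params[name].default
        if a != b:
            mismatches[name] = (a, b)
    return mismatches


# Hypothetical stand-ins for a wrapped loss and the loss wrapping it.
class BaseLoss:
    def __init__(self, temperature: float = 0.5, gather_distributed: bool = False):
        self.temperature = temperature
        self.gather_distributed = gather_distributed


class SlicedLoss:
    def __init__(
        self,
        loss_dim: int = 64,  # assumed value, for illustration only
        temperature: float = 0.5,
        gather_distributed: bool = False,
    ):
        self.loss_dim = loss_dim
        self.inner = BaseLoss(temperature, gather_distributed)


# An empty dict means the shared defaults are still in sync.
in_sync = shared_defaults_mismatch(SlicedLoss, BaseLoss)
```

In a test suite this would reduce to a single assertion that the returned dict is empty, so a changed default in either class fails the build instead of silently diverging.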
811cd71 to 12aacd1
Resolved the temperature parameter and default parameter comments.
yutong-xiang-97
left a comment
LGTM!
liopeer
left a comment
Everything addressed from my side as well. Amazing work @KylevdLangemheen !
Issue #781 proposes adding DirectCLR to Lightly. In PR #963 a loss was proposed but never merged.
As the paper uses a version of InfoNCE that is identical to NT-Xent, we can use our existing NT-Xent implementation instead. This way, implementing DirectCLR becomes a simple matter of taking the first d values of each embedding and applying the NT-Xent loss to those (if my understanding is correct 😄).
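To make the idea concrete, here is a dependency-free sketch of the computation described above: slice each embedding to its first d values, then apply NT-Xent to the slices. It is written in plain Python for illustration and is not the Lightly implementation; the function names `nt_xent` and `directclr_loss`, and the default temperature of 0.5, are assumptions made for this example.

```python
import math
from typing import List, Sequence


def _unit(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def nt_xent(z0, z1, temperature: float = 0.5) -> float:
    """NT-Xent over B paired views: each of the 2B embeddings has one
    positive (its partner view) and 2B - 2 negatives."""
    z = [_unit(v) for v in list(z0) + list(z1)]
    n = len(z0)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)
        # Similarities to every other embedding, excluding self-similarity.
        logits = [_dot(z[i], z[j]) / temperature for j in range(2 * n) if j != i]
        peak = max(logits)  # subtract the max for a stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_sum_exp - _dot(z[i], z[positive]) / temperature
    return total / (2 * n)


def directclr_loss(x0, x1, loss_dim: int, temperature: float = 0.5) -> float:
    """Apply NT-Xent to only the first loss_dim values of each embedding."""
    return nt_xent(
        [v[:loss_dim] for v in x0], [v[:loss_dim] for v in x1], temperature
    )


# Two pairs of 3-dimensional embeddings; only the first two values are used.
x0 = [[1.0, 0.0, 5.0], [0.0, 1.0, -3.0]]
x1 = [[1.0, 0.0, -7.0], [0.0, 1.0, 9.0]]
aligned = directclr_loss(x0, x1, loss_dim=2)
swapped = directclr_loss(x0, x1[::-1], loss_dim=2)  # mismatched pairs
```

Because the third value of each embedding is dropped before the loss is computed, it has no effect on the result, and the correctly paired batch scores a lower loss than the batch with its pairs swapped.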
As DirectCLR does away with the projection head, I did not see an obvious way to implement it with the low-level module blocks. The proposed implementation is added as a deprecated model, but perhaps it does not even need its own model file and can just live as an example.
This PR:
This PR does not:
Let me know if you would prefer to see a variant outside of the deprecated high-level models. I could make a "dummy" projection head which just slices the embedding for example. Also let me know if you want me to include further features :)