Adding an integration test for JEPA #1519
Conversation
@@ -0,0 +1,102 @@
"""
How much code duplication is there with the other integration tests?
Quite a bit; a refactoring would be beneficial. Tim said we should initially copy-paste the tests and refactor later.
Can you open an issue so that we track the refactoring?
uv run --offline pytest ./integration_tests/small1_test.py --verbose -s
)
;;
integration-test-jepa)
Can we also add an `integration-test-all` that runs all the integration tests?
sure
yes please
src/weathergen/model/model.py
Outdated
stream_name = self.stream_names[i_obs]
if cf.training_config.losses["LossPhysical"].weight > 0.0:
    for i_obs, si in enumerate(cf.streams):
        stream_name = self.stream_names[i_obs]
i_obs -> i_stream
src/weathergen/model/model.py
Outdated
| if "LossPhysical" in loss_calculators: | ||
| for i_obs, si in enumerate(cf.streams): | ||
| stream_name = self.stream_names[i_obs] | ||
| if cf.training_config.losses["LossPhysical"].weight > 0.0: |
These changes are just due to incomplete diff visibility on GitHub?
I added these `weight > 0.0` checks to switch off the LossPhysical-related parts of the code; otherwise it tries to use both losses, because the configs are appended.
This will break for latent SSL training without LossPhysical, so please use `cf.training_config.get("LossPhysical")`.
I am still confused by the large number of changes in model.py where I cannot see what actually changed.
Thanks for catching, switched to `.get()`.
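For anyone following the thread, here is a minimal, self-contained sketch of why `.get()` matters here. `LossConfig` and the `losses` dict are stand-ins for the config objects referenced in the diff, and the `LossJEPA` name is purely illustrative:

```python
from dataclasses import dataclass


@dataclass
class LossConfig:
    """Stand-in for a loss entry in cf.training_config.losses."""
    weight: float


# A latent SSL run that never configures LossPhysical.
losses = {"LossJEPA": LossConfig(weight=1.0)}

# losses["LossPhysical"].weight would raise a KeyError here; .get() returns
# None instead, so the LossPhysical-specific branch is skipped cleanly.
loss_physical = losses.get("LossPhysical")
if loss_physical is not None and loss_physical.weight > 0.0:
    print("set up LossPhysical-specific parts")
else:
    print("skip LossPhysical-specific parts")
```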
losses_all[calculator.name] = loss_values.losses_all
losses_all[calculator.name]["loss_avg"] = loss_values.loss
stddev_all[calculator.name] = loss_values.stddev_all
if weight > 0.0:
These changes appear just due to GitHub?
same as above
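To make the intent of the `if weight > 0.0:` guard concrete, a self-contained sketch follows. Only the `losses_all`, `stddev_all`, and `loss_avg` names come from the diff; `LossValues`, the calculator list, and the loop are illustrative stand-ins:

```python
from dataclasses import dataclass, field


@dataclass
class LossValues:
    """Stand-in for the result object returned by a loss calculator."""
    loss: float
    losses_all: dict = field(default_factory=dict)
    stddev_all: dict = field(default_factory=dict)


losses_all: dict = {}
stddev_all: dict = {}

# (name, weight, result) triples standing in for the configured calculators.
calculators = [
    ("LossJEPA", 1.0, LossValues(loss=0.42)),
    ("LossPhysical", 0.0, LossValues(loss=0.0)),
]

for name, weight, loss_values in calculators:
    # A zero weight skips the aggregation entirely, so a disabled
    # LossPhysical leaves no entries in the output dictionaries.
    if weight > 0.0:
        losses_all[name] = dict(loss_values.losses_all)
        losses_all[name]["loss_avg"] = loss_values.loss
        stddev_all[name] = loss_values.stddev_all

print(losses_all)  # {'LossJEPA': {'loss_avg': 0.42}}
```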
@clessig have a look please
    stream_name=f"embed_target_coords_{stream_name}",
)
else:
    assert False
nit: this is not informative. Can you either raise an exception with a full message or, if you prefer asserts, do:
assert etc["net"] in ["mlp", "linear"], etc["net"]
if etc["net"] == "linear":
...then the value that is causing the problem is clear.
It's unrelated to the PR. We can do this as part of #1537.
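For when #1537 is picked up, a self-contained sketch of the two alternatives suggested above; only `etc["net"]` and the "mlp"/"linear" values come from the thread, the rest is illustrative:

```python
etc = {"net": "linear"}  # stand-in for the config dict in the diff

# Option 1: an assert whose message carries the offending value, so a bad
# config fails with e.g. "AssertionError: conv" instead of a bare assert.
assert etc["net"] in ["mlp", "linear"], etc["net"]

# Option 2: an explicit exception with a full message.
if etc["net"] == "mlp":
    pass  # build the mlp variant
elif etc["net"] == "linear":
    pass  # build the linear variant
else:
    raise ValueError(f'unsupported value for etc["net"]: {etc["net"]!r}')
```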
To be merged after @clessig's work on the config. Looking forward to having a test for SSL.
Description
The new SSL modes need integration tests; this is one for JEPA.
Note that I had other PRs open, but those got complicated by other branches being merged, so I opened a clean one.
Issue Number
Closes #1516
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60