Conversation

@TillHae (Contributor) commented on Dec 25, 2025

Description

This PR addresses an issue where the number of samples and the starting epoch were calculated incorrectly when continuing a training run with a different number of nodes.

  • Added samples as a config parameter. It always holds the cumulative number of samples processed across all training segments.
  • Captured world_size_original in run_train.py before the environment initialization overwrites it. This ensures the model correctly interprets the context of the run it is continuing from.
  • EMA decay schedules now use the total number of samples processed, preventing "jumps" or resets when the node count changes.
  • Added backward compatibility in TrainLogger to safely initialize the new samples parameter from istep for legacy runs (sketched below).
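
For illustration, here is a minimal sketch of the bookkeeping described above. It assumes a config object cf exposing the new samples field and the pre-existing istep; the function names, the hypothetical batch_size_per_rank field, the istep-to-samples fallback, and the warmup-style decay schedule are illustrative only, not the code in this PR.

```python
# Illustrative sketch only -- not the implementation in this PR.


def init_cumulative_samples(cf) -> int:
    """Backward compatibility: legacy checkpoints carry no `samples` field,
    so fall back to deriving it from `istep` (assumed proportionality)."""
    if getattr(cf, "samples", None) is not None:
        return cf.samples
    # Hypothetical fallback: istep optimizer steps, each consuming
    # batch_size_per_rank samples on each of the original ranks.
    return cf.istep * cf.batch_size_per_rank * cf.world_size_original


def ema_decay(samples: int, warmup_samples: int, decay_max: float = 0.9999) -> float:
    """Warmup-style EMA decay driven by cumulative samples, so the schedule
    continues smoothly (no jump or reset) when the run resumes on a
    different number of nodes."""
    t = samples / max(warmup_samples, 1)
    return min(decay_max, (1.0 + t) / (10.0 + t))
```

Because both helpers depend only on the cumulative samples counter, restarting with more or fewer ranks changes neither the recovered starting point nor the EMA schedule.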

Issue Number

Closes #587

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Angie25 and others added 21 commits December 25, 2025 03:42
* Log gradient norms

* Prototype for recording grad norms

* Address review changes + hide behind feature flag

* Final fixes including backward compatibility

* Ruff

* More ruff stuff

* Update to develop, prepare for new experiment series

* forecast config with small decoder

* fixed uv.lock

* test gradient logging on mutli gpus

* Setting o48 as default in era5 config

Committer: Matthias Karlbauer <matthias.karlbauer@ecmwf.int>

On branch mk/develop/fe_experiments
Your branch is ahead of 'origin/mk/develop/fe_experiments' by 57 commits.
  (use "git push" to publish your local commits)

Changes to be committed:
  modified:   config/streams/era5_1deg/era5.yml

* Updated default config to 256 dim latent size

On branch mk/develop/fe_experiments
Your branch is ahead of 'origin/mk/develop/fe_experiments' by 58 commits.
  (use "git push" to publish your local commits)

Changes to be committed:
	modified:   config/default_config.yml

* Update branch to latest develop

* Change epochs from 64 to 32

* LayerNorm replication and analysis tools

* Rename fe_layer_norm_at_layers to fe_layer_norm_after_blocks

* Increase epochs from 32 to 64 and resolve minor bug

* Update default_config back to d2048 on the O96 grid

* Update ERA5 stream to O96 grid

* Resolving bug after merging with develop and updating default_config

* Enable loading old model checkpoints after recent merges

* Update WeatherGenReader with mini-epoch notation

* Minor modifications to latent histogram plotting

* Resolve bug in histogram plotting

* Replace getattr by cf.get

* Change target read-out engine from 1 to 2 layers

* Set aux-info for fe-blocks to none

* fix a plotting bug (ecmwf#1453)

* Update train/val dates, HL=5, fsteps=2, lat-weighting

* removed plotting latent histograms

* modified configs

* removed the eval and train plot configs

* added 00 as minutes

* lint

* added fc config + renamed to fe_impute_latent_noise_std

* lint

* removed parameter renaming for backward compatibility

* removed weight_progression and plot_grad files

* corrected end_date

* using .get()

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Matthias Karlbauer <matthias.karlbauer@ecmwf.int>
Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
Co-authored-by: Julian Kuehnert <julian.b.kuehnert@gmail.com>
Co-authored-by: Matthias Karlbauer <mkarlbau@santis-ln002.cscs.ch>
Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com>
…aining, and pass ModelBatch class (ecmwf#1283)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occured for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match input one in forcast mode. Unclear how to deal with it with MTM

* Changes to  prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of targert ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calcualtor

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to mergeing

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target stratgies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow up fixes required to handle partially filler source/target streams (because source has no target values, eg).

* An encoder formed by embedding + local assimilation + global assimilation (ecmwf#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printong

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness that arose through this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixed to validation

* Adding stream_id

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (ecmwf#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved strucutre of LatentState class.

* Improved structure of LatentState

* Fixed problem wiht num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when runing in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (ecmwf#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (ecmwf#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems brokens for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (ecmwf#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
* setting fe_num_blocks: 0

* updated def forecast config

* reverting to values 4096/512 for no of samples

* 32->64 epoches

* added the forecast settings

* using relationship: independent

* lint

* remove metrics.py for the PR

* readded metrics.py

* num_register_tokens: 0

* val start date october
* modified ratio plot

* heat-map

* fix cosmetics

* lint

* change config

* add metric name to heatmap

* multiply forecasts by step_hrs

* remove breakpoints

* fix variable order

* lint

* Fix a minor bug on score cards (ecmwf#1492)

* Correct a minor bug regarding score cards

* Linting

---------

Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com>
* add default template

* update instructions
* using new parameter

* using the self. parameter
* Revert "Make missing fstep 0 produce a more meaningfull error (ecmwf#1444)"

This reverts commit e3b29be.

* restore integration test

* Change to step_timedelta in warning

* fix integration test after update of metric logging

---------

Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
* Fix DataReaderFesom

* Linter

* Remove hardcoded ERA5 name

* Fix iteration over streams

* Unlinter
* new config for multi integration test

* new conf multi-integration test - align with big merge

* increased samples_per_validation

* [1458] fix single integration test

* [1435] fix single-integration test

---------

Co-authored-by: simone99n <simone.norberti@gmail.com>
Co-authored-by: Simone Norberti <63310821+simone99n@users.noreply.github.com>
…cmwf#1507)

* Healpix cropping simple implementation

* Healpix cropping simple implementation with control over the num_samples and overlap + fixing the num_sample bug

* Fixed lint

* actually merging develop in properly with merge conflicts resolved

* move imports, move functions

* make functions for different types of cropping rather than if else, remove overlap from select_spatially_contiguous_cells

* restore overlap control in select spatially contiguous cell

* clean up cropping, remove overlap control of crops for now

* make overlap work with source and target masks. working now as random selection. need to think about this if we want to explicitly do overlap of crops.

* restored config somewhat

* lint

* restore config a bit

* remove extra commented out code

* clean up

* remove logging

* invert healpix_cropping masking so aligned with masking and healpix

* lint and updated comments

* remove overlap code, deal with in _get_mask complement, subset etc

* update comments and trues falses masking

* remove legacy argument constraint_keep_mask in the docstring for 2 masking functions

* remove legacy overlap_ratio and overlap from the docstrings in masking

* remove overlap ratio unused arg from example configs for cropping

* make cropping a function, and build shared _prepare_healpix_masking for healpix and healpix_cropping preparation

* rename healpix preparation function

* removed old docstrings and lint

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side to me is how to allow for a
fleixble teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark

* adding loss calculator base class

* Option for constructing teacher model flexibly

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* abstract loss calc structure

* add abstract method to loss calculator base class

* add latent loss class

* update loss calc config and rename files

* restructure loss modules

* add ModelOutput dataclass

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occured for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* mv streams_data declaration under if condition

* add weight to loss config, add toy loss class LossPhysicalTwo

* Update Abstract Target class based on needs for SSL losses

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* fixed trainer for multiple terms in losses_all, still need to fix logging

* Inversion of target output ordering to match input one in forcast mode. Unclear how to deal with it with MTM

* fix _log_terminal

* Changes to  prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* fix logging

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* initialize loss as torch tensor with grad

* remove level in hist losses dict

* rename loss.py to loss_functions.py

* rename loss.py to loss_functions.py

* return loss with grads seperately to trainer

* Added mode and refactored get_sample_data into separate function.

* modify log names

* add loss_functions.py

* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side to me is how to allow for a
fleixble teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark

* Option for constructing teacher model flexibly

* rm loss_fcts in default config

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* Lay groundwork for SSL losses

This involves creating stateful classes for each of the losses and the
EMATeacher being able to run additional neural network heads for these
losses.

* Add the SSL Loss Processing classes

* Write part of the TargetProcessing forward

TODO: create the various teacher head modules and run them.
TODO: merge the abstract loss calculator and create the SSL one

* Add latent prediction heads to the Model

After much consideration I decided to add the latent prediction heads to
the Model, because they also need to benefit from exponential moving
average of the weights and this gets unnecessarily cumbersome if they
are outside the Model.

TODO: make JEPA different between student and teacher
TODO: use this new structure in EMATeacher

* Adapt forward function for latent prediction heads

To prevent crazy nesting of model output values we created a ModelOutput
Dataclass (akin to how it is done in huggingface), and we run all the
latent_prediction heads.

* Start piping configs through model, trainer, etc

Will need adapting based on the abstract loss calculator

Currently is awaiting the streams data branch to check piping of data
and configuring this

* adding dinov2 notice

* Draft Student Teacher Loss Calculator

TODO: initialise it and register
TODO: weight the loss
TODO: route the kwargs
TODO: check shapes of tensors

* Use infra provided by Abstract Loss Calc

Completes config option routing, weighting, and registering TODOs

* Run Ruff

* Implemented the first draft of the Cropping feature

* rough first effort producing globaland local views

* update to return 6 tuple from iter in multi-stream-data-sampler, with locals_prepared

* Fix class being in the wrong file

* Ensure data pipes through model and target

This is a DRAFT!

This commit assumes that the data augmentations of the stream_data
objectsee shmh40/dev/global-local will fit into the Batch data class
(trainer.py).

The goal was to ensure all data reaches the LossCalculator.

Large scale todos:
- Pass Metadata about local2global correspondance, etc to the LossCalculator
- Upgrade the Model heads to produce the correct dimensions
- Verify the Data shapes against DINOv2

Smaller todos:
- Ensure teacher params are requires_grad = false
- clean up code

* Wrap latent state into a dataclass

to simply loss calculation later

* Progress on computing the loss on correct dims

Added the new ViewMetadat and ModelBatch dataclasses that will come from
the cropping PR

Added LatentState dataclass to compute the latent heads on the correct
part of the latent state

TODOs:
1. Deal with DINO local and global component
2. Write JEPA loss function in loss.py
3. Test iBOT with actual mask and route student temperature
4. TODOs in the code

* Add views.py and run Ruff

* Close in on completing DINO loss

TODO needs to deal with the tuple part of the DINO loss
TODO understand how the ModelBatch input structure affects the loss
terms

* Revert "rough first effort producing globaland local views"

This reverts commit 3fa0033.

* Lint code

* Fix rebase of loss loss_calculator

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* Test for compute time regressions

* Prepare for merge

* Lint the code

* Lint code

* Lint

* Fix some basic bugs

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of targert ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* Removing spurious code / things that should be merged later

* Linting

* move build_views_for_stream into masker

* Lint code

* Rename identity TargetAndAux module

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* Make code runnable

* updated configs so code runs. Note default config to be overhauled still

* Draft for model interface

* Make code runnable again

Seems slow again

* Cleaned up and restructured structure. Not working yet with FSDP

* Fixes for FSDP/DDP

* Cleaning up, should be merged when needed

* Fixes to FSDP

* Fix incorrect args for model loading and removing unused code.

* Linting

* Removing old code

* - Fixing inference arg order
- Fixing subtle problem with world_size_original that should be taken from config when available

* Fixing interface of get_target_aux_calculator

* Fixing call to target aux calculator

* Fixes to get_target_aux_calculator

* Remove stale dataclasses

* Fix MAE

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* Prepare for another merge

* remove prints, pdb

* Save state

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* Save state for Seb

Currently re-viving the EMATeacher creation

Memory is an issue, had to hardcode a smaller latent space

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

* Attemp to make the iBOT loss work

TODO force 1 ibot student view per global view
TODO there is a bug with the mask causing a leaf error in pytorch
TODO remove all the hardcoded reduced latent space

* Linting

* Simplified and clarified handling of default target_aux_calcualtor

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to mergeing

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* Pipe data through all ssl loss fns

TODO iBOT head should output class tokens as well as patch tokens
TODO remove hardcoded assignments, should be based on config
TODO deal with the memory hungriness of it all
TODO carefully inspect for bugs

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target stratgies

* Pipe the mask through

* Filter student views for the correct loss

* Change the masking and msdp to fit student-teacher

1. We ensure that for each target view all the student views are generated
2. We ensure that the target views have their mask applied to the input

* Make DINO and iBOT work

TODO: use the target mask to reduce memory

* Prepare for Model PR introducing class & reg token

Thus, right now it breaks. The big question is memory!

* Integrate the class and register token PR

Done manually because I couldn't figure out how to merge from a fork

* Fix iBOT loss with correct PredHead

Limitation: iBOT loss needs num_class_tokens to be 1

* Fix JEPA + Lint code

* Fix DDP

It had unused parameters from the decoders these had to be removed

* Running this code + config for JEPA with DDP

* Ran JEPA DDP plot with this

* Fix FSDP error

Potentially a slow down, but I don't understand FSDP well enough for a better fix

* Fix conig

* Fix validation

* Stuck on error taking a break

* hot fix to empty tokens_c in encoder when looping over chunks

* Revert "hot fix to empty tokens_c in encoder when looping over chunks"

This reverts commit e4519d8.

* hot fix for local assimilation empty tokens_c

* Add class tokens being variable + Fix bugs

* Push remaining changes to default config

* deepcopy configs so we do not pop weight and lose it for inference

* fixed bug in inference with +2 in forecast steps range

* add required import to trainer

* Update uv.lock

* Linting

* Record fstep latent states

* added two configs, jepa and ibot/dino. Note these configs still try to merge/overwrite the default config with the --config flag

* Addres comments from PR review

* Prepare SSL losses for logging

Currently nothing happens in the terminal but I don't know why that is

* Lint

* Address PR comments+ upstream changes

* Appease the hidden linter

* Rename ssl_losses_utils

* Add the untracked file

* Removing spurious character

---------

Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Sebastian Hickman <seb.hickman@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Sophie Xhonneux <sxhonneu@santis-ln001.cscs.ch>
Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
@TillHae changed the title from "Thauer/develop/issue 587" to "Fix sample calculation jumping on run continuation" on Dec 25, 2025
self.cf.istep
/ (min(len_per_rank, cf.samples_per_mini_epoch) * self.world_size_original)
)
# recover mini_epoch from cumulative samples (invariant to world size change)
Collaborator

we do not store the mini_epoch in the config state? that sounds dangerous, it should be the source of truth

Contributor (Author)

I've created a separate issue for that #1543

Collaborator

I've created a separate issue for that #1543

We removed mini-epoch from the config a while ago, and for good reasons. We cannot keep it consistent with istep. We agreed (there's a design doc) that istep should be the single source of truth and all functionality should entirely rely on this. Mini-epoch should just be a convenience that is used say for terminal output to orient the user.
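
For orientation, a minimal sketch of that convention, assuming istep counts cumulative samples as in the diff excerpt above; the helper name and signature are hypothetical, not part of this PR:

```python
# Hypothetical sketch: mini_epoch is never persisted; it is derived from
# istep (the single source of truth) only when needed, e.g. for terminal output.
def mini_epoch_for_display(
    istep: int,
    len_per_rank: int,
    samples_per_mini_epoch: int,
    world_size_original: int,
) -> int:
    samples_per_mini_epoch_total = (
        min(len_per_rank, samples_per_mini_epoch) * world_size_original
    )
    return istep // samples_per_mini_epoch_total
```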

@tjhunter (Collaborator) left a comment

@TillHae thanks for looking into this annoying issue. Here are some initial comments.

Successfully merging this pull request may close these issues.

Wrong istep when continue training with different nb of nodes
