Code for the NeurIPS 2023 paper "Rewiring Neurons in Non-Stationary Environments".
- Make sure you have PyTorch and JAX installed with CUDA support.
- Install SaLinA and Continual World following their instructions. Note that the latter only supports MuJoCo version 2.0.
- Install additional packages via `pip install -r requirements.txt`.
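For reference, a typical setup might look like the sketch below; the exact PyTorch and JAX packages depend on your CUDA version, so follow the official install guides rather than copying these lines verbatim:

```
# Illustrative only -- adjust package specs to your CUDA version.
pip install torch             # see pytorch.org for the CUDA-specific command
pip install "jax[cuda12]"     # see the JAX install guide for the matching wheel
pip install -r requirements.txt
```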
Simply run run.py with the desired config available in configs:
```
python run.py -cn=METHOD scenario=SCENARIO OPTIONAL_CONFIGS
```
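For example, to train our method on the HalfCheetah forgetting scenario (both described below):

```
python run.py -cn=rewire scenario=halfcheetah/forgetting
```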
We present 9 different CRL methods, all built on top of the Soft Actor-Critic (SAC) algorithm. To try one, just add the flag -cn=METHOD on the command line, where METHOD is one of the names below. You can find the hyperparameters in configs:
- rewire: our method in "Rewiring Neurons in Non-Stationary Environments".
- ft_1: fine-tune a single policy during the whole training.
- sac_n: fine-tune and save the policy at the end of each task; start with a randomly initialized policy when encountering a new task.
- ft_n: fine-tune and save the policy at the end of each task; clone the last policy when encountering a new task.
- ft_l2: fine-tune a single policy during the whole training with a regularization cost, a simpler variant of EWC (see the sketch after this list).
- ewc: see the paper "Overcoming Catastrophic Forgetting in Neural Networks".
- pnn: see the paper "Progressive Neural Networks".
- packnet: see the paper "PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning".
- csp: see the paper "Building a Subspace of Policies for Scalable Continual Learning".
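To illustrate the idea behind ft_l2, here is a minimal PyTorch sketch of an L2 penalty that anchors the current policy to the weights saved at the end of the previous task. Names such as `anchor_params` and `reg_coef` are illustrative, not taken from this repo:

```python
import torch

def l2_anchor_penalty(policy, anchor_params, reg_coef=1.0):
    """L2 regularization toward parameters saved after the previous task.

    `anchor_params` is a list of detached tensors snapshotted at the end
    of the previous task; `reg_coef` trades off plasticity vs. stability.
    (Illustrative sketch only -- not the repo's actual implementation.)
    """
    penalty = torch.zeros((), device=next(policy.parameters()).device)
    for p, p_old in zip(policy.parameters(), anchor_params):
        penalty = penalty + ((p - p_old) ** 2).sum()
    return reg_coef * penalty

# Usage inside a training step (sketch):
# loss = sac_loss + l2_anchor_penalty(policy, anchor_params, reg_coef=1e-2)
```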
We integrate 9 CRL scenarios over 3 different Brax domains and 2 scenarios of the Continual World domain. To try them, just add the flag scenario=... on the command line:
- halfcheetah/forgetting: 8 tasks - 1M samples for each task.
- halfcheetah/transfer: 8 tasks - 1M samples for each task.
- halfcheetah/robustness: 8 tasks - 1M samples for each task.
- halfcheetah/compositionality: 8 tasks - 1M samples for each task.
- ant/forgetting: 8 tasks - 1M samples for each task.
- ant/transfer: 8 tasks - 1M samples for each task.
- ant/robustness: 8 tasks - 1M samples for each task.
- ant/compositionality: 8 tasks - 1M samples for each task.
- humanoid/hard: 4 tasks - 2M samples for each task.
- continual_world/t1-t8: 8 triplets of 3 tasks - 1M samples for each task.
- continual_world/cw10: 10 tasks - 1M samples for each task.
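For instance, to run our method on the CW10 scenario from Continual World:

```
python run.py -cn=rewire scenario=continual_world/cw10
```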
The core.py file contains the building blocks of this framework. Each experiment consists of running a Framework over a Scenario, i.e. a sequence of train and test Tasks. The models are learning procedures that use CRL agents to interact with the tasks and learn from them through one or multiple algorithms. A rough sketch of these abstractions follows the list below.
- frameworks contains generic learning procedures (e.g. using only one algorithm, or adding a regularization method at the end).
- scenarios contains CRL scenarios, i.e. sequences of train and test tasks.
- algorithms contains different RL / CL algorithms (e.g. SAC, or EWC).
- agents contains CRL agents (e.g. PackNet, CSP, or Rewire).
- configs contains the config files of the benchmarked methods/scenarios.
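As a rough illustration of how these pieces fit together, here is a minimal, hypothetical sketch of the core abstractions; the class and method names are illustrative and do not mirror core.py exactly:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Task:
    """One environment in the sequence, with a training budget."""
    name: str
    n_samples: int  # e.g. 1_000_000 per task in most scenarios

class Scenario:
    """A CRL scenario: an ordered sequence of train and test tasks."""
    def __init__(self, train_tasks: List[Task], test_tasks: List[Task]):
        self.train_tasks = train_tasks
        self.test_tasks = test_tasks

class Framework:
    """A learning procedure that trains a CRL agent task by task."""
    def __init__(self, agent, algorithm):
        self.agent = agent          # e.g. Rewire, PackNet, CSP
        self.algorithm = algorithm  # e.g. SAC, possibly wrapped with EWC

    def run(self, scenario: Scenario):
        for task in scenario.train_tasks:
            self.algorithm.train(self.agent, task)  # interact and learn
        return [self.evaluate(t) for t in scenario.test_tasks]

    def evaluate(self, task: Task):
        ...  # roll out the agent on a test task and report the return
```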
Our implementation is based on:
- SaLinA
- Continual World
- Brax