From 13b13ac1c75e339e1501d3692de4d44a6414e5bb Mon Sep 17 00:00:00 2001 From: Sam Toyer Date: Thu, 20 Aug 2020 16:53:18 -0700 Subject: [PATCH 1/5] Initial tool settings Move all deps to setup.py Fewer torch conflicts Fix version conflict Hopefully fix CircleCI Forgotten pytest plugins --- .circleci/config.yml | 3 --- reformat.sh | 14 +++++++++++++ requirements.txt | 21 -------------------- setup.cfg | 17 ++++++++++++++-- setup.py | 47 ++++++++++++++++++++++++++++++++++++-------- 5 files changed, 68 insertions(+), 34 deletions(-) create mode 100755 reformat.sh delete mode 100644 requirements.txt diff --git a/.circleci/config.yml b/.circleci/config.yml index 7bffdfb2..6f817070 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -6,9 +6,6 @@ jobs: - image: humancompatibleai/il-representations:2020.08.03-r3 steps: - checkout - - run: - command: pip install -r requirements.txt - name: Install dependencies - run: command: curl -so ~/.mujoco/mjkey.txt "${MUJOCO_KEY}" name: Set up MuJoCo diff --git a/reformat.sh b/reformat.sh new file mode 100755 index 00000000..6c363ced --- /dev/null +++ b/reformat.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# Reformats imports and source code so that you don't have to + +set -xe + +SRC_FILES=(src/ tests/ setup.py) + +echo "Sorting imports" +isort -r ${SRC_FILES[@]} +echo "Removing unused imports" +autoflake --in-place --expand-star-imports --remove-all-unused-imports -r ${SRC_FILES[@]} +echo "Reformatting source code" +yapf -ir ${SRC_FILES[@]} diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 1bf7ec58..00000000 --- a/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -numpy~=1.19.0 -gym[atari]~=0.17 -sacred~=0.8.1 -torch~=1.5.1 -opencv-python~=4.3.0.36 -torchvision~=0.6.1 -pyyaml~=5.3.1 -sacred~=0.8.1 -tensorboard~=2.2.0 -pytest~=5.4.3 - -# imitation needs special branch as of 2020-08-17 -git+git://github.com/HumanCompatibleAI/imitation@image-env-changes#egg=imitation -git+https://github.com/HumanCompatibleAI/stable-baselines3.git@imitation#egg=stable-baselines3 - -# environments -# (MAGICAL currently requires an old version of Pyglet) -pyglet==1.3.* -git+https://github.com/qxcv/magical@master#egg=magical -dm_control~=0.0.319497192 -git+git://github.com/denisyarats/dmc2gym@6e34d8acf18e92f0ea0a38ecee9564bdf2549076#egg=dmc2gym diff --git a/setup.cfg b/setup.cfg index 498ce434..e1ff1647 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,22 @@ [isort] -line_length=79 +line_length=100 known_first_party=il_representations default_section=THIRDPARTY multi_line_output=0 force_sort_within_sections=True +[yapf] +based_on_style=pep8 +column_limit=100 + +[flake8] +max-line-length=100 +ignore=E266,E261 + [tool:pytest] -testpaths=tests \ No newline at end of file +# adding all these to testpaths is necessary to make flake8 and isort run on +# everything +testpaths= + tests/ + src/ +addopts=--isort --flake8 diff --git a/setup.py b/setup.py index 6a9c28b2..bc9c01ce 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ from setuptools import find_packages, setup - setup( name="il-representations", version="0.0.1", @@ -9,11 +8,43 @@ python_requires=">=3.7.0", packages=find_packages("src"), package_dir={"": "src"}, - # FIXME(sam): move from requirements.txt to setup.py once merge is done - install_requires=[], - # FIXME(sam): keeping this as reminder to add all experiment scripts as - # console_scripts - # entry_points={ - # "console_scripts": [], - # }, + install_requires=[ + "numpy~=1.19.0", + "gym[atari]==0.17.*", + "sacred~=0.8.1", + "torch==1.6.*", + "torchvision==0.7.*", + "opencv-python~=4.3.0.36", + "pyyaml~=5.3.1", + "sacred~=0.8.1", + "tensorboard~=2.2.0", + + # testing/dev utils + "pytest~=5.4.3", + "isort~=5.0", + "yapf~=0.30.0", + "flake8~=3.8.3", + "autoflake~=1.3.1", + "pytest-flake8~=1.0.6", + "pytest-isort~=1.1.0", + + # imitation needs special branch as of 2020-08-20 + ("imitation @ git+git://github.com/HumanCompatibleAI/imitation" + "@image-env-changes#egg=imitation"), + ("stable_baselines3 @ git+https://github.com/HumanCompatibleAI/stable-baselines3.git" + "@imitation#egg=stable-baselines3"), + + # environments + "magical @ git+https://github.com/qxcv/magical@master", + "dm_control~=0.0.319497192", + ("dmc2gym @ git+git://github.com/denisyarats/dmc2gym" + "@6e34d8acf18e92f0ea0a38ecee9564bdf2549076"), + ], + entry_points={ + "console_scripts": [ + "run_rep_learner=il_representations.scripts.run_rep_learner:main", + "il_train=il_representations.scripts.il_train:main", + "il_test=il_representations.scripts.il_test:main", + ], + }, ) From 314d4c934718d37b30f3a01b326246f12b6fc55c Mon Sep 17 00:00:00 2001 From: Sam Toyer Date: Wed, 7 Oct 2020 15:29:23 -0700 Subject: [PATCH 2/5] Manual formatting fixes --- src/il_representations/envs/config.py | 3 +++ src/il_representations/scripts/il_test.py | 5 ++++- src/il_representations/scripts/il_train.py | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/il_representations/envs/config.py b/src/il_representations/envs/config.py index dbd91a23..2ddfa0f2 100644 --- a/src/il_representations/envs/config.py +++ b/src/il_representations/envs/config.py @@ -80,3 +80,6 @@ def bench_defaults(): 'PongNoFrameskip-v4': "data/atari/PongNoFrameskip-v4_rollouts_500_ts_100_traj.npz", } + + _ = locals() + del _ diff --git a/src/il_representations/scripts/il_test.py b/src/il_representations/scripts/il_test.py index 8693d95c..7cef0877 100644 --- a/src/il_representations/scripts/il_test.py +++ b/src/il_representations/scripts/il_test.py @@ -30,9 +30,12 @@ def default_config(): # being tested run_id = 'test' + _ = locals() + del _ + @il_test_ex.main -def test(policy_path, benchmark, seed, n_rollouts, device_name, run_id): +def run(policy_path, benchmark, seed, n_rollouts, device_name, run_id): set_global_seeds(seed) # FIXME(sam): this is not idiomatic way to do logging (as in il_train.py) logging.basicConfig(level=logging.INFO) diff --git a/src/il_representations/scripts/il_train.py b/src/il_representations/scripts/il_train.py index 267ea5f8..de512cef 100644 --- a/src/il_representations/scripts/il_train.py +++ b/src/il_representations/scripts/il_train.py @@ -31,6 +31,9 @@ def bc_defaults(): n_epochs = 250 # noqa: F841 augs = 'rotate,translate,noise' # noqa: F841 + _ = locals() + del _ + gail_ingredient = Ingredient('gail') @@ -60,6 +63,9 @@ def gail_defaults(): ppo_ent = 1e-5 # noqa: F841 ppo_adv_clip = 0.05 # noqa: F841 + _ = locals() + del _ + il_train_ex = Experiment('il_train', ingredients=[ benchmark_ingredient, bc_ingredient, gail_ingredient, @@ -88,6 +94,9 @@ def default_config(): n_envs=16, ) + _ = locals() + del _ + def make_policy(observation_space, action_space, encoder_or_path, lr_schedule=None): # TODO(sam): this should be unified with the representation learning code From 50dd00b31022a7f971d977ef44fde279b210d83f Mon Sep 17 00:00:00 2001 From: Sam Toyer Date: Wed, 7 Oct 2020 15:42:16 -0700 Subject: [PATCH 3/5] More tooling fixes --- reformat.sh | 17 +++++++++++------ setup.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/reformat.sh b/reformat.sh index 6c363ced..dc2ff54c 100755 --- a/reformat.sh +++ b/reformat.sh @@ -6,9 +6,14 @@ set -xe SRC_FILES=(src/ tests/ setup.py) -echo "Sorting imports" -isort -r ${SRC_FILES[@]} -echo "Removing unused imports" -autoflake --in-place --expand-star-imports --remove-all-unused-imports -r ${SRC_FILES[@]} -echo "Reformatting source code" -yapf -ir ${SRC_FILES[@]} +# sometimes we need a couple of runs to get to a setting that all the tools are +# happy with +n_runs=2 +for run in seq 1 $n_runs; do + echo "Reformatting source code (run $run/$n_runs)" + yapf -ir ${SRC_FILES[@]} + echo "Sorting imports (repeat $run/$n_runs)" + isort ${SRC_FILES[@]} + echo "Removing unused imports (run $run/$n_runs)" + autoflake --in-place --expand-star-imports --remove-all-unused-imports -r ${SRC_FILES[@]} +done diff --git a/setup.py b/setup.py index 55154c12..2dbe72f5 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "@imitation#egg=stable-baselines3"), # environments - "magical @ git+https://github.com/qxcv/magical@master", + "magical @ git+https://github.com/qxcv/magical@pyglet1.5", "dm_control~=0.0.319497192", ("dmc2gym @ git+git://github.com/denisyarats/dmc2gym" "@6e34d8acf18e92f0ea0a38ecee9564bdf2549076"), From f0907c5d3d22ccd2d72960164e360486ca4c372d Mon Sep 17 00:00:00 2001 From: Sam Toyer Date: Wed, 7 Oct 2020 16:00:50 -0700 Subject: [PATCH 4/5] Huge batch of auto-formatting fixes --- reformat.sh | 2 +- setup.cfg | 4 +- src/il_representations/algos/__init__.py | 109 ++++++---- src/il_representations/algos/augmenters.py | 15 +- src/il_representations/algos/base_learner.py | 7 +- .../algos/batch_extenders.py | 33 +-- src/il_representations/algos/decoders.py | 166 +++++++++------ src/il_representations/algos/encoders.py | 195 +++++++++++------- src/il_representations/algos/losses.py | 143 ++++++++----- src/il_representations/algos/optimizers.py | 22 +- .../algos/pair_constructors.py | 69 ++++--- .../algos/representation_learner.py | 135 +++++++----- src/il_representations/algos/utils.py | 54 ++--- src/il_representations/data.py | 7 +- src/il_representations/envs/atari_envs.py | 11 +- src/il_representations/envs/auto.py | 17 +- src/il_representations/envs/config.py | 6 +- .../envs/dm_control_envs.py | 62 ++---- src/il_representations/envs/magical_envs.py | 41 ++-- src/il_representations/il/disc_rew_nets.py | 8 +- src/il_representations/policy_interfacing.py | 24 ++- src/il_representations/scripts/il_test.py | 25 +-- src/il_representations/scripts/il_train.py | 26 +-- .../scripts/run_rep_learner.py | 45 ++-- .../test_support/configuration.py | 12 +- tests/conftest.py | 3 +- tests/test_base_algos.py | 33 +-- tests/test_il_train_test.py | 29 ++- tests/test_reload_policy.py | 8 +- 29 files changed, 746 insertions(+), 565 deletions(-) diff --git a/reformat.sh b/reformat.sh index dc2ff54c..a2c5fed0 100755 --- a/reformat.sh +++ b/reformat.sh @@ -9,7 +9,7 @@ SRC_FILES=(src/ tests/ setup.py) # sometimes we need a couple of runs to get to a setting that all the tools are # happy with n_runs=2 -for run in seq 1 $n_runs; do +for run in $(seq 1 $n_runs); do echo "Reformatting source code (run $run/$n_runs)" yapf -ir ${SRC_FILES[@]} echo "Sorting imports (repeat $run/$n_runs)" diff --git a/setup.cfg b/setup.cfg index 1712e8f9..537bf6ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,7 +11,7 @@ column_limit=100 [flake8] max-line-length=100 -ignore=E266,E261 +ignore=E266,E261,W504 [tool:pytest] # adding all these to testpaths is necessary to make flake8 and isort run on @@ -23,4 +23,4 @@ addopts=--isort --flake8 filterwarnings= ignore:.*importing the ABCs from 'collections' instead of from 'collections.abc'.*:DeprecationWarning ignore:.*Box bound precision lowered by casting to float32.*:UserWarning - ignore:.*The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors.*:UserWarning \ No newline at end of file + ignore:.*The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors.*:UserWarning diff --git a/src/il_representations/algos/__init__.py b/src/il_representations/algos/__init__.py index 4a4711b5..38f12438 100644 --- a/src/il_representations/algos/__init__.py +++ b/src/il_representations/algos/__init__.py @@ -1,13 +1,19 @@ -from il_representations.algos.representation_learner import RepresentationLearner, DEFAULT_HARDCODED_PARAMS -from il_representations.algos.encoders import MomentumEncoder, InverseDynamicsEncoder, DynamicsEncoder, RecurrentEncoder, StochasticEncoder, DeterministicEncoder -from il_representations.algos.decoders import ProjectionHead, NoOp, MomentumProjectionHead, BYOLProjectionHead, ActionConditionedVectorDecoder, TargetProjection -from il_representations.algos.losses import SymmetricContrastiveLoss, AsymmetricContrastiveLoss, MSELoss, CEBLoss, \ - QueueAsymmetricContrastiveLoss, BatchAsymmetricContrastiveLoss - -from il_representations.algos.augmenters import AugmentContextAndTarget, AugmentContextOnly, NoAugmentation -from il_representations.algos.pair_constructors import IdentityPairConstructor, TemporalOffsetPairConstructor +from il_representations.algos.augmenters import (AugmentContextAndTarget, AugmentContextOnly, + NoAugmentation) from il_representations.algos.batch_extenders import QueueBatchExtender -from il_representations.algos.optimizers import LARS +from il_representations.algos.decoders import (ActionConditionedVectorDecoder, BYOLProjectionHead, + MomentumProjectionHead, NoOp, ProjectionHead, + TargetProjection) +from il_representations.algos.encoders import (DeterministicEncoder, DynamicsEncoder, + InverseDynamicsEncoder, MomentumEncoder, + RecurrentEncoder, StochasticEncoder) +from il_representations.algos.losses import (BatchAsymmetricContrastiveLoss, CEBLoss, MSELoss, + QueueAsymmetricContrastiveLoss, + SymmetricContrastiveLoss) +from il_representations.algos.pair_constructors import (IdentityPairConstructor, + TemporalOffsetPairConstructor) +from il_representations.algos.representation_learner import (DEFAULT_HARDCODED_PARAMS, + RepresentationLearner) class SimCLR(RepresentationLearner): @@ -15,10 +21,11 @@ class SimCLR(RepresentationLearner): Implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations https://arxiv.org/abs/2002.05709 - This method works by using a contrastive loss to push together representations of two differently-augmented - versions of the same image. In particular, it uses a symmetric contrastive loss, which compares the - (target, context) similarity against similarity of context with all other targets, and also similarity - of target with all other contexts. + This method works by using a contrastive loss to push together + representations of two differently-augmented versions of the same image. In + particular, it uses a symmetric contrastive loss, which compares the + (target, context) similarity against similarity of context with all other + targets, and also similarity of target with all other contexts. """ def __init__(self, env, log_dir, **kwargs): kwargs = self.validate_and_update_kwargs(kwargs) @@ -39,7 +46,8 @@ def __init__(self, env, log_dir, temporal_offset=1, **kwargs): Implementation of a non-recurrent version of CPC: Contrastive Predictive Coding https://arxiv.org/abs/1807.03748 - By default, augments only the context, but can be modified to augment both context and target. + By default, augments only the context, but can be modified to augment + both context and target. """ kwargs_updates = {'target_pair_constructor_kwargs': {'temporal_offset': temporal_offset}} kwargs = self.validate_and_update_kwargs(kwargs, kwargs_updates=kwargs_updates) @@ -58,8 +66,9 @@ class RecurrentCPC(RepresentationLearner): Implementation of a recurrent version of CPC: Contrastive Predictive Coding https://arxiv.org/abs/1807.03748 - The encoder first encodes individual frames for both context and target, and then, for the context, - builds up a recurrent representation of all prior frames in the same trajectory, to use to predict the target. + The encoder first encodes individual frames for both context and target, + and then, for the context, builds up a recurrent representation of all + prior frames in the same trajectory, to use to predict the target. By default, augments only the context, but can be modified to augment both context and target. """ @@ -100,7 +109,6 @@ class MoCoWithProjection(RepresentationLearner): Includes an additional projection head atop the representation and before the prediction """ - def __init__(self, env, log_dir, **kwargs): hardcoded_params = DEFAULT_HARDCODED_PARAMS + ['batch_extender'] kwargs = self.validate_and_update_kwargs(kwargs, hardcoded_params=hardcoded_params) @@ -119,15 +127,16 @@ class DynamicsPrediction(RepresentationLearner): def __init__(self, env, log_dir, **kwargs): kwargs_updates = {'target_pair_constructor_kwargs': {'mode': 'dynamics'}} kwargs = self.validate_and_update_kwargs(kwargs, kwargs_updates=kwargs_updates) - super().__init__(env=env, - log_dir=log_dir, - encoder=DynamicsEncoder, - # Should be a pixel decoder that takes in action, currently errors - decoder=NoOp, - loss_calculator=MSELoss, - augmenter=AugmentContextOnly, - target_pair_constructor=TemporalOffsetPairConstructor, - **kwargs) + super().__init__( + env=env, + log_dir=log_dir, + encoder=DynamicsEncoder, + # Should be a pixel decoder that takes in action, currently errors + decoder=NoOp, + loss_calculator=MSELoss, + augmenter=AugmentContextOnly, + target_pair_constructor=TemporalOffsetPairConstructor, + **kwargs) def learn(self, dataset, training_epochs): raise NotImplementedError("DynamicsPrediction is not yet fully implemented") @@ -138,15 +147,16 @@ def __init__(self, env, log_dir, **kwargs): kwargs_updates = {'target_pair_constructor_kwargs': {'mode': 'inverse_dynamics'}} kwargs = self.validate_and_update_kwargs(kwargs, kwargs_updates=kwargs_updates) - super().__init__(env=env, - log_dir=log_dir, - encoder=InverseDynamicsEncoder, - # Should be a action decoder that takes in next obs representation - decoder=NoOp, - loss_calculator=MSELoss, - augmenter=AugmentContextOnly, - target_pair_constructor=TemporalOffsetPairConstructor, - **kwargs) + super().__init__( + env=env, + log_dir=log_dir, + encoder=InverseDynamicsEncoder, + # Should be a action decoder that takes in next obs representation + decoder=NoOp, + loss_calculator=MSELoss, + augmenter=AugmentContextOnly, + target_pair_constructor=TemporalOffsetPairConstructor, + **kwargs) def learn(self, dataset, training_epochs): raise NotImplementedError("InverseDynamicsPrediction is not yet fully implemented") @@ -186,6 +196,7 @@ def __init__(self, env, log_dir, **kwargs): target_pair_constructor=TemporalOffsetPairConstructor, **kwargs) + class FixedVarianceCEB(RepresentationLearner): """ CEB with fixed rather than learned variance @@ -201,6 +212,7 @@ def __init__(self, env, log_dir, **kwargs): target_pair_constructor=TemporalOffsetPairConstructor, **kwargs) + class FixedVarianceTargetProjectedCEB(RepresentationLearner): """ """ @@ -218,17 +230,25 @@ def __init__(self, env, log_dir, **kwargs): class ActionConditionedTemporalCPC(RepresentationLearner): """ - Implementation of reinforcement-learning-specific variant of Temporal CPC which adds a projection layer on top - of the learned representation which integrates an encoding of the actions taken between time (t) and whatever - time (t+k) is specified in temporal_offset and used for pulling out the target frame. This, notionally, allows - the algorithm to construct frame representations that are action-independent, rather than marginalizing over an - expected policy, as might need to happen if the algorithm needed to predict the frame at time (t+k) over any - possible action distribution. + Implementation of reinforcement-learning-specific variant of Temporal CPC + which adds a projection layer on top of the learned representation which + integrates an encoding of the actions taken between time (t) and whatever + time (t+k) is specified in temporal_offset and used for pulling out the + target frame. This, notionally, allows the algorithm to construct frame + representations that are action-independent, rather than marginalizing over + an expected policy, as might need to happen if the algorithm needed to + predict the frame at time (t+k) over any possible action distribution. """ def __init__(self, env, log_dir, **kwargs): - kwargs_updates = {'preprocess_extra_context': False, - 'target_pair_constructor_kwargs': {"mode": "dynamics"}, - 'decoder_kwargs': {'action_space': env.action_space}} + kwargs_updates = { + 'preprocess_extra_context': False, + 'target_pair_constructor_kwargs': { + "mode": "dynamics" + }, + 'decoder_kwargs': { + 'action_space': env.action_space + } + } kwargs = self.validate_and_update_kwargs(kwargs, kwargs_updates=kwargs_updates) super().__init__(env=env, @@ -239,5 +259,6 @@ def __init__(self, env, log_dir, **kwargs): loss_calculator=BatchAsymmetricContrastiveLoss, **kwargs) + ## Algos that should not be run in all-algo test because they are not yet finished WIP_ALGOS = [DynamicsPrediction, InverseDynamicsPrediction] diff --git a/src/il_representations/algos/augmenters.py b/src/il_representations/algos/augmenters.py index b9365414..e2141ffc 100644 --- a/src/il_representations/algos/augmenters.py +++ b/src/il_representations/algos/augmenters.py @@ -1,21 +1,16 @@ -import enum -from torchvision import transforms -from imitation.augment.color import ColorSpace # noqa: F401 -from imitation.augment.convenience import StandardAugmentations -from il_representations.algos.utils import gaussian_blur -import torch -from abc import ABC, abstractmethod -import PIL """ These are pretty basic: when constructed, they take in a list of augmentations, and either augment just the context, or both the context and the target, depending on the algorithm. """ +from abc import ABC, abstractmethod + +from imitation.augment.color import ColorSpace # noqa: F401 +from imitation.augment.convenience import StandardAugmentations class Augmenter(ABC): def __init__(self, augmenter_spec, color_space): - augment_op = StandardAugmentations.from_string_spec( - augmenter_spec, color_space) + augment_op = StandardAugmentations.from_string_spec(augmenter_spec, color_space) self.augment_op = augment_op @abstractmethod diff --git a/src/il_representations/algos/base_learner.py b/src/il_representations/algos/base_learner.py index 6ff9ca98..3143f8ac 100644 --- a/src/il_representations/algos/base_learner.py +++ b/src/il_representations/algos/base_learner.py @@ -1,4 +1,5 @@ import gym + from il_representations.algos.utils import set_global_seeds @@ -13,12 +14,10 @@ def __init__(self, env): # if EncoderSimplePolicyHead is refactored. if isinstance(self.action_space, gym.spaces.Discrete): self.action_size = env.action_space.n - elif (isinstance(self.action_space, gym.spaces.Box) - and len(self.action_space.shape) == 1): + elif (isinstance(self.action_space, gym.spaces.Box) and len(self.action_space.shape) == 1): self.action_size, = self.action_space.shape else: - raise NotImplementedError( - f"can't handle action space {self.action_space}") + raise NotImplementedError(f"can't handle action space {self.action_space}") def set_random_seed(self, seed): if seed is None: diff --git a/src/il_representations/algos/batch_extenders.py b/src/il_representations/algos/batch_extenders.py index a721485d..e3ad4b94 100644 --- a/src/il_representations/algos/batch_extenders.py +++ b/src/il_representations/algos/batch_extenders.py @@ -1,14 +1,17 @@ +""" +BatchExtenders are used in situations where you want to pass a batch forward +for loss that is different than the batch seen by your encoder. The currently +implemented situation where this is the case is Momentum, where you want to +pass forward a bunch of negatives from prior encoding runs to increase the +difficulty of your prediction task. One might also imagine this being useful +for doing trajectory-mixing in a RNN case where batches naturally need to be +all from a small number of trajectories, but this isn't yet implemented. +""" from abc import ABC, abstractmethod + import torch -from torch.distributions import Normal + from il_representations.algos.utils import independent_multivariate_normal -""" -BatchExtenders are used in situations where you want to pass a batch forward for loss that is different than the -batch seen by your encoder. The currently implemented situation where this is the case is Momentum, where you want -to pass forward a bunch of negatives from prior encoding runs to increase the difficulty of your prediction task. -One might also imagine this being useful for doing trajectory-mixing in a RNN case where batches naturally need -to be all from a small number of trajectories, but this isn't yet implemented. -""" class BatchExtender(ABC): @@ -34,18 +37,22 @@ def __init__(self, queue_dim, device, queue_size=8192, sample=False): self.queue_ptr = 0 def __call__(self, context_dist, target_dist): - # Call up current contents of the queue, duplicate. Add targets to the queue, - # potentially overriding old information in the process. Return targets concatenated to contents of queue + # Call up current contents of the queue, duplicate. Add targets to the + # queue, potentially overriding old information in the process. Return + # targets concatenated to contents of queue targets_loc = target_dist.loc targets_covariance = target_dist.covariance_matrix - # Pull out the diagonals of our MultivariateNormal covariance matrices, so we don't store all the extra 0s - targets_scale = torch.stack([batch_element_matrix.diag() for batch_element_matrix in targets_covariance]) + # Pull out the diagonals of our MultivariateNormal covariance matrices, + # so we don't store all the extra 0s + targets_scale = torch.stack( + [batch_element_matrix.diag() for batch_element_matrix in targets_covariance]) batch_size = targets_loc.shape[0] queue_targets_scale = (self.queue_scale.clone().detach()).to(self.device) queue_targets_loc = (self.queue_loc.clone().detach()).to(self.device) - # TODO: Currently requires the queue size to be a multiple of the batch size. Don't require that. + # TODO: Currently requires the queue size to be a multiple of the batch + # size. Don't require that. self.queue_loc[self.queue_ptr:self.queue_ptr + batch_size] = targets_loc self.queue_scale[self.queue_ptr:self.queue_ptr + batch_size] = targets_scale self.queue_ptr = (self.queue_ptr + batch_size) % self.queue_size diff --git a/src/il_representations/algos/decoders.py b/src/il_representations/algos/decoders.py index d70e2651..9398d97d 100644 --- a/src/il_representations/algos/decoders.py +++ b/src/il_representations/algos/decoders.py @@ -1,31 +1,36 @@ -import functools -import torch.nn as nn +""" +LossDecoders are meant to be mappings between the representation being learned, +and the representation or tensor that is fed directly into the loss. In many +cases, these are the same, and this will just be a NoOp. + +Some cases where it is different: +- When you are using a Projection Head in your contrastive loss, and comparing + similarities of vectors that are k >=1 nonlinear layers downstream from the + actual representation you'll use in later tasks +- When you're learning a VAE, and the loss is determined by how effectively you + can reconstruct the image from a representation vector, the LossDecoder will + handle that representation -> image mapping +- When you're predicting actions given current and next state, you'll want to + predict those actions given both the representation of the current state, and + also information about the next state. This occasional need for extra + information beyond the central context state is why we have `extra_context` + as an optional bit of data that pair constructors can return, to be passed + forward for use here +""" import copy -import torch -import torch.nn.functional as F -from torch.distributions import MultivariateNormal -from il_representations.algos.utils import independent_multivariate_normal +import functools + import gym.spaces as spaces import numpy as np +import torch +from torch.distributions import MultivariateNormal +import torch.nn as nn +import torch.nn.functional as F +from il_representations.algos.utils import independent_multivariate_normal -""" -LossDecoders are meant to be mappings between the representation being learned, -and the representation or tensor that is fed directly into the loss. In many cases, these are the -same, and this will just be a NoOp. - -Some cases where it is different: -- When you are using a Projection Head in your contrastive loss, and comparing similarities of vectors that are -k >=1 nonlinear layers downstream from the actual representation you'll use in later tasks -- When you're learning a VAE, and the loss is determined by how effectively you can reconstruct the image -from a representation vector, the LossDecoder will handle that representation -> image mapping -- When you're predicting actions given current and next state, you'll want to predict those actions given -both the representation of the current state, and also information about the next state. This occasional -need for extra information beyond the central context state is why we have `extra_context` as an optional -bit of data that pair constructors can return, to be passed forward for use here -""" +# TODO change shape to dim throughout this file and the code -#TODO change shape to dim throughout this file and the code class LossDecoder(nn.Module): def __init__(self, representation_dim, projection_shape, sample=False): @@ -51,7 +56,10 @@ def get_vector(self, z_dist): return z_dist.loc def ones_like_projection_dim(self, x): - return torch.ones(size=(x.shape[0], self.projection_dim,), device=x.device) + return torch.ones(size=( + x.shape[0], + self.projection_dim, + ), device=x.device) class NoOp(LossDecoder): @@ -63,7 +71,8 @@ class TargetProjection(LossDecoder): def __init__(self, representation_dim, projection_shape, sample=False, learn_scale=False): super(TargetProjection, self).__init__(representation_dim, projection_shape, sample) - self.target_projection = nn.Sequential(nn.Linear(self.representation_dim, self.projection_dim)) + self.target_projection = nn.Sequential( + nn.Linear(self.representation_dim, self.projection_dim)) def decode_context(self, z_dist, traj_info, extra_context=None): return z_dist @@ -71,17 +80,16 @@ def decode_context(self, z_dist, traj_info, extra_context=None): def decode_target(self, z_dist, traj_info, extra_context=None): z_vector = self.get_vector(z_dist) mean = self.target_projection(z_vector) - return torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=z_dist.covariance_matrix) + return torch.distributions.MultivariateNormal(loc=mean, + covariance_matrix=z_dist.covariance_matrix) class ProjectionHead(LossDecoder): def __init__(self, representation_dim, projection_shape, sample=False, learn_scale=False): super(ProjectionHead, self).__init__(representation_dim, projection_shape, sample) - self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256), - nn.ReLU(), - nn.Linear(256, 256), - nn.ReLU()) + self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256), nn.ReLU(), + nn.Linear(256, 256), nn.ReLU()) self.mean_layer = nn.Linear(256, self.projection_dim) if learn_scale: @@ -92,14 +100,24 @@ def __init__(self, representation_dim, projection_shape, sample=False, learn_sca def forward(self, z_dist, traj_info, extra_context=None): z = self.get_vector(z_dist) shared_repr = self.shared_mlp(z) - return independent_multivariate_normal(loc=self.mean_layer(shared_repr), scale=torch.exp(self.scale_layer(shared_repr))) + return independent_multivariate_normal(loc=self.mean_layer(shared_repr), + scale=torch.exp(self.scale_layer(shared_repr))) class MomentumProjectionHead(LossDecoder): - def __init__(self, representation_dim, projection_shape, sample=False, momentum_weight=0.99, learn_scale=False): - super(MomentumProjectionHead, self).__init__(representation_dim, projection_shape, sample=sample) - self.context_decoder = ProjectionHead(representation_dim, projection_shape, - sample=sample, learn_scale=learn_scale) + def __init__(self, + representation_dim, + projection_shape, + sample=False, + momentum_weight=0.99, + learn_scale=False): + super(MomentumProjectionHead, self).__init__(representation_dim, + projection_shape, + sample=sample) + self.context_decoder = ProjectionHead(representation_dim, + projection_shape, + sample=sample, + learn_scale=learn_scale) self.target_decoder = copy.deepcopy(self.context_decoder) for param in self.target_decoder.parameters(): param.requires_grad = False @@ -110,32 +128,39 @@ def decode_context(self, z_dist, traj_info, extra_context=None): def decode_target(self, z_dist, traj_info, extra_context=None): """ - Encoder target/keys using momentum-updated key encoder. Had some thought of making _momentum_update_key_encoder - a backwards hook, but seemed overly complex for an initial POC + Encoder target/keys using momentum-updated key encoder. Had some + thought of making _momentum_update_key_encoder a backwards hook, but + seemed overly complex for an initial POC :param x: :return: """ with torch.no_grad(): self._momentum_update_key_encoder() decoded_z_dist = self.target_decoder(z_dist, traj_info, extra_context=extra_context) - return MultivariateNormal(loc=decoded_z_dist.loc.detach(), covariance_matrix=decoded_z_dist.covariance_matrix.detach()) + return MultivariateNormal(loc=decoded_z_dist.loc.detach(), + covariance_matrix=decoded_z_dist.covariance_matrix.detach()) @torch.no_grad() def _momentum_update_key_encoder(self): - for param_q, param_k in zip(self.context_decoder.parameters(), self.target_decoder.parameters()): - param_k.data = param_k.data * self.momentum_weight + param_q.data * (1. - self.momentum_weight) + for param_q, param_k in zip(self.context_decoder.parameters(), + self.target_decoder.parameters()): + param_k.data = param_k.data * self.momentum_weight + param_q.data * ( + 1. - self.momentum_weight) class BYOLProjectionHead(MomentumProjectionHead): def __init__(self, representation_dim, projection_shape, momentum_weight=0.99, sample=False): - super(BYOLProjectionHead, self).__init__(representation_dim, projection_shape, - sample=sample, momentum_weight=momentum_weight) + super(BYOLProjectionHead, self).__init__(representation_dim, + projection_shape, + sample=sample, + momentum_weight=momentum_weight) self.context_predictor = ProjectionHead(projection_shape, projection_shape) def forward(self, z_dist, traj_info, extra_context=None): internal_dist = super().forward(z_dist, traj_info, extra_context=extra_context) prediction_dist = self.context_predictor(internal_dist, traj_info, extra_context=None) - return independent_multivariate_normal(loc=F.normalize(prediction_dist.loc, dim=1), scale=prediction_dist.scale) + return independent_multivariate_normal(loc=F.normalize(prediction_dist.loc, dim=1), + scale=prediction_dist.scale) def decode_target(self, z_dist, traj_info, extra_context=None): with torch.no_grad(): @@ -145,39 +170,60 @@ def decode_target(self, z_dist, traj_info, extra_context=None): class ActionConditionedVectorDecoder(LossDecoder): - def __init__(self, representation_dim, projection_shape, action_space, sample=False, action_encoding_dim=128, - action_encoder_layers=1, learn_scale=False, action_embedding_dim=5, use_lstm=False): - super(ActionConditionedVectorDecoder, self).__init__(representation_dim, projection_shape, sample=sample) + def __init__(self, + representation_dim, + projection_shape, + action_space, + sample=False, + action_encoding_dim=128, + action_encoder_layers=1, + learn_scale=False, + action_embedding_dim=5, + use_lstm=False): + super(ActionConditionedVectorDecoder, self).__init__(representation_dim, + projection_shape, + sample=sample) self.learn_scale = learn_scale - # Machinery for turning raw actions into vectors. If actions are discrete, this is done via an Embedding. - # If actions are continuous/box, this is done via a simple flattening. + # Machinery for turning raw actions into vectors. If actions are + # discrete, this is done via an Embedding. If actions are + # continuous/box, this is done via a simple flattening. if isinstance(action_space, spaces.Discrete): - self.action_processor = nn.Embedding(num_embeddings=action_space.n, embedding_dim=action_embedding_dim) + self.action_processor = nn.Embedding(num_embeddings=action_space.n, + embedding_dim=action_embedding_dim) processed_action_dim = action_embedding_dim self.action_shape = () # discrete actions are just numbers elif isinstance(action_space, spaces.Box): - self.action_processor = functools.partial(torch.flatten, - start_dim=2) + self.action_processor = functools.partial(torch.flatten, start_dim=2) processed_action_dim = np.prod(action_space.shape) self.action_shape = action_space.shape else: - raise NotImplementedError("Action conditioning is only currently implemented for Discrete and Box action spaces") + raise NotImplementedError( + "Action conditioning is only currently implemented for Discrete and Box " + "action spaces") - # Machinery for aggregating information from an arbitrary number of actions into a single vector, - # either through a LSTM, or by simply averaging the vector representations of the k states together + # Machinery for aggregating information from an arbitrary number of + # actions into a single vector, either through a LSTM, or by simply + # averaging the vector representations of the k states together if use_lstm: - self.action_encoder = nn.LSTM(processed_action_dim, action_encoding_dim, action_encoder_layers, batch_first=True) + self.action_encoder = nn.LSTM(processed_action_dim, + action_encoding_dim, + action_encoder_layers, + batch_first=True) else: self.action_encoder = None action_encoding_dim = processed_action_dim - # Machinery for mapping a concatenated (context representation, action representation) into a projection - self.action_conditioned_projection = nn.Linear(representation_dim + action_encoding_dim, projection_shape) + # Machinery for mapping a concatenated (context representation, action + # representation) into a projection + self.action_conditioned_projection = nn.Linear(representation_dim + action_encoding_dim, + projection_shape) - # If learning scale/std deviation parameter, declare a layer for that, otherwise, return a unit-constant vector + # If learning scale/std deviation parameter, declare a layer for that, + # otherwise, return a unit-constant vector if self.learn_scale: - self.scale_projection = nn.Linear(representation_dim + action_encoding_dim, projection_shape) + self.scale_projection = nn.Linear(representation_dim + action_encoding_dim, + projection_shape) else: self.scale_projection = self.ones_like_projection_dim @@ -209,9 +255,9 @@ def decode_context(self, z_dist, traj_info, extra_context=None): assert action_encoding_vector.shape[0] == batch_dim, \ action_encoding_vector.shape - # Concatenate context representation and action representation and map to a merged representation + # Concatenate context representation and action representation and map + # to a merged representation merged_vector = torch.cat([z, action_encoding_vector], dim=1) mean_projection = self.action_conditioned_projection(merged_vector) scale = self.scale_projection(merged_vector) return independent_multivariate_normal(loc=mean_projection, scale=scale) - diff --git a/src/il_representations/algos/encoders.py b/src/il_representations/algos/encoders.py index 1ca15034..75549400 100644 --- a/src/il_representations/algos/encoders.py +++ b/src/il_representations/algos/encoders.py @@ -1,26 +1,23 @@ -import torch -import torch.nn as nn +""" +Encoders conceptually serve as the bit of the representation learning +architecture that learns the representation itself (except in RNN cases, where +encoders only learn the per-frame representation). + +The only real complex thing to note here is the MomentumEncoder architecture, +which creates two CNNEncoders, and updates weights of one as a slowly moving +average of the other. Note that this bit of momentum is separated from the +creation and filling of a queue of representations, which is handled by the +BatchExtender module +""" import copy -from torch.distributions import MultivariateNormal -from functools import reduce -import numpy as np -from stable_baselines3.common.preprocessing import preprocess_obs -from gym.spaces import Box -from il_representations.algos.utils import independent_multivariate_normal import numpy as np +from stable_baselines3.common.preprocessing import preprocess_obs import torch from torch import nn +from torch.distributions import MultivariateNormal - -""" -Encoders conceptually serve as the bit of the representation learning architecture that learns the representation itself -(except in RNN cases, where encoders only learn the per-frame representation). - -The only real complex thing to note here is the MomentumEncoder architecture, which creates two CNNEncoders, -and updates weights of one as a slowly moving average of the other. Note that this bit of momentum is separated -from the creation and filling of a queue of representations, which is handled by the BatchExtender module -""" +from il_representations.algos.utils import independent_multivariate_normal def compute_output_shape(observation_space, layers): @@ -65,13 +62,28 @@ def __init__(self, observation_space, representation_dim): # first apply convolution layers + flattening conv_arch = [ - {'out_dim': 32, 'kernel_size': 8, 'stride': 4}, - {'out_dim': 64, 'kernel_size': 4, 'stride': 2}, - {'out_dim': 64, 'kernel_size': 3, 'stride': 1}, + { + 'out_dim': 32, + 'kernel_size': 8, + 'stride': 4 + }, + { + 'out_dim': 64, + 'kernel_size': 4, + 'stride': 2 + }, + { + 'out_dim': 64, + 'kernel_size': 3, + 'stride': 1 + }, ] for layer_spec in conv_arch: - shared_network_layers.append(nn.Conv2d(self.input_channel, layer_spec['out_dim'], - kernel_size=layer_spec['kernel_size'], stride=layer_spec['stride'])) + shared_network_layers.append( + nn.Conv2d(self.input_channel, + layer_spec['out_dim'], + kernel_size=layer_spec['kernel_size'], + stride=layer_spec['stride'])) shared_network_layers.append(nn.ReLU()) self.input_channel = layer_spec['out_dim'] shared_network_layers.append(nn.Flatten()) @@ -80,7 +92,9 @@ def __init__(self, observation_space, representation_dim): dense_in_dim, = compute_output_shape(observation_space, shared_network_layers) dense_arch = [ # this input size is accurate for Atari, but will be ovewritten for other envs - {'in_dim': 64*7*7}, + { + 'in_dim': 64 * 7 * 7 + }, ] dense_arch[0]['in_dim'] = dense_in_dim dense_arch[-1]['out_dim'] = representation_dim @@ -102,30 +116,30 @@ def forward(self, x): class MAGICALCNN(nn.Module): """The CNN from the MAGICAL paper.""" - def __init__(self, - observation_space, - representation_dim, - # TODO(sam): enable BN by default once I'm sure that .train() - # and .eval() are used correctly throughout the codebase. - use_bn=False, - use_ln=False, - dropout=None, - use_sn=False, - width=2, - fc_dim=128, - ActivationCls=torch.nn.ReLU): + def __init__( + self, + observation_space, + representation_dim, + # TODO(sam): enable BN by default once I'm sure that .train() + # and .eval() are used correctly throughout the codebase. + use_bn=False, + use_ln=False, + dropout=None, + use_sn=False, + width=2, + fc_dim=128, + ActivationCls=torch.nn.ReLU): super().__init__() def conv_block(in_chans, out_chans, kernel_size, stride, padding): # We sometimes disable bias because batch norm has its own bias. - conv_layer = nn.Conv2d( - in_chans, - out_chans, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=not use_bn, - padding_mode='zeros') + conv_layer = nn.Conv2d(in_chans, + out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=not use_bn, + padding_mode='zeros') if use_sn: # apply spectral norm if necessary @@ -234,12 +248,19 @@ def encode_extra_context(self, x, traj_info): class DeterministicEncoder(Encoder): - def __init__(self, obs_space, representation_dim, obs_encoder_cls=None, scale_constant=1, **kwargs): + def __init__(self, + obs_space, + representation_dim, + obs_encoder_cls=None, + scale_constant=1, + **kwargs): """ :param obs_space: The observation space that this Encoder will be used on - :param representation_dim: The number of dimensions of the representation that will be learned - :param obs_encoder_cls: An internal architecture implementing `forward` to return a single vector - representing the mean representation z of a fixed-variance representation distribution + :param representation_dim: The number of dimensions of the + representation that will be learned + :param obs_encoder_cls: An internal architecture implementing `forward` + to return a single vector representing the mean representation z of + a fixed-variance representation distribution """ super().__init__(**kwargs) obs_encoder_cls = get_obs_encoder_cls(obs_encoder_cls) @@ -252,15 +273,23 @@ def forward(self, x, traj_info): class StochasticEncoder(Encoder): - def __init__(self, obs_space, representation_dim, obs_encoder_cls=None, latent_dim=None, **kwargs): + def __init__(self, + obs_space, + representation_dim, + obs_encoder_cls=None, + latent_dim=None, + **kwargs): """ :param obs_space: The observation space that this Encoder will be used on - :param representation_dim: The number of dimensions of the representation that will be learned + :param representation_dim: The number of dimensions of the + representation that will be learned :param obs_encoder_cls: An internal architecture implementing `forward` to return a single vector. This is expected NOT to end in a ReLU (i.e. final layer should be linear). :param latent_dim: Dimension of the latents that feed into mean and std networks (default: representation_dim * 2). - two vectors, representing the mean AND learned standard deviation of a representation distribution + + two vectors, representing the mean AND learned standard deviation of a + representation distribution """ super().__init__(**kwargs) obs_encoder_cls = get_obs_encoder_cls(obs_encoder_cls) @@ -289,8 +318,13 @@ def encode_extra_context(self, x, traj_info): class MomentumEncoder(Encoder): - def __init__(self, obs_shape, representation_dim, learn_scale=False, - momentum_weight=0.999, obs_encoder_cls=None, **kwargs): + def __init__(self, + obs_shape, + representation_dim, + learn_scale=False, + momentum_weight=0.999, + obs_encoder_cls=None, + **kwargs): super().__init__(**kwargs) obs_encoder_cls = get_obs_encoder_cls(obs_encoder_cls) if learn_scale: @@ -308,35 +342,49 @@ def forward(self, x, traj_info): def encode_target(self, x, traj_info): """ - Encoder target/keys using momentum-updated key encoder. Had some thought of making _momentum_update_key_encoder - a backwards hook, but seemed overly complex for an initial proof of concept + Encoder target/keys using momentum-updated key encoder. Had some + thought of making _momentum_update_key_encoder a backwards hook, but + seemed overly complex for an initial proof of concept :param x: :return: """ with torch.no_grad(): self._momentum_update_key_encoder() z_dist = self.key_encoder(x, traj_info) - return MultivariateNormal(loc=z_dist.loc.detach(), covariance_matrix=z_dist.covariance_matrix.detach()) + return MultivariateNormal(loc=z_dist.loc.detach(), + covariance_matrix=z_dist.covariance_matrix.detach()) @torch.no_grad() def _momentum_update_key_encoder(self): for param_q, param_k in zip(self.query_encoder.parameters(), self.key_encoder.parameters()): - param_k.data = param_k.data * self.momentum_weight + param_q.data * (1. - self.momentum_weight) + param_k.data = param_k.data * self.momentum_weight + param_q.data * ( + 1. - self.momentum_weight) class RecurrentEncoder(Encoder): - def __init__(self, obs_shape, representation_dim, learn_scale=False, num_recurrent_layers=2, - single_frame_repr_dim=None, min_traj_size=5, obs_encoder_cls=None, rnn_output_dim=64, **kwargs): + def __init__(self, + obs_shape, + representation_dim, + learn_scale=False, + num_recurrent_layers=2, + single_frame_repr_dim=None, + min_traj_size=5, + obs_encoder_cls=None, + rnn_output_dim=64, + **kwargs): super().__init__(**kwargs) obs_encoder_cls = get_obs_encoder_cls(obs_encoder_cls) self.num_recurrent_layers = num_recurrent_layers self.min_traj_size = min_traj_size self.representation_dim = representation_dim - self.single_frame_repr_dim = representation_dim if single_frame_repr_dim is None else single_frame_repr_dim + self.single_frame_repr_dim = representation_dim if single_frame_repr_dim is None \ + else single_frame_repr_dim self.single_frame_encoder = DeterministicEncoder(obs_shape, self.single_frame_repr_dim, obs_encoder_cls) - self.context_rnn = nn.LSTM(self.single_frame_repr_dim, rnn_output_dim, - self.num_recurrent_layers, batch_first=True) + self.context_rnn = nn.LSTM(self.single_frame_repr_dim, + rnn_output_dim, + self.num_recurrent_layers, + batch_first=True) self.mean_layer = nn.Linear(rnn_output_dim, self.representation_dim) if learn_scale: self.scale_layer = nn.Linear(rnn_output_dim, self.representation_dim) @@ -344,31 +392,38 @@ def __init__(self, obs_shape, representation_dim, learn_scale=False, num_recurre self.scale_layer = self.ones_like_representation_dim def ones_like_representation_dim(self, x): - return torch.ones(size=(x.shape[0], self.representation_dim,), device=x.device) + return torch.ones(size=( + x.shape[0], + self.representation_dim, + ), device=x.device) def _reshape_and_stack(self, z, traj_info): batch_size = z.shape[0] input_shape = z.shape[1:] trajectory_id, timesteps = traj_info # We should have trajectory_id values for every element in the batch z - assert len(z) == len(trajectory_id), "Every element in z must have a trajectory ID in a RecurrentEncoder" + assert len(z) == len( + trajectory_id), "Every element in z must have a trajectory ID in a RecurrentEncoder" # A set of all distinct trajectory IDs trajectories = torch.unique(trajectory_id) padded_trajectories = [] mask_lengths = [] for trajectory in trajectories: traj_timesteps = timesteps[trajectory_id == trajectory] - assert list(traj_timesteps) == sorted(list(traj_timesteps)), "Batches must be sorted to use a RecurrentEncoder" - # Get all Z vectors associated with a trajectory, which have now been confirmed to be sorted timestep-wise + assert list(traj_timesteps) == sorted( + list(traj_timesteps)), "Batches must be sorted to use a RecurrentEncoder" + # Get all Z vectors associated with a trajectory, which have now + # been confirmed to be sorted timestep-wise traj_z = z[trajectory_id == trajectory] # Keep track of how many actual unpadded values were in the trajectory mask_lengths.append(traj_z.shape[0]) pad_size = batch_size - traj_z.shape[0] - padding = torch.zeros((pad_size,) + input_shape).to(self.device) + padding = torch.zeros((pad_size, ) + input_shape).to(self.device) padded_z = torch.cat([traj_z, padding]) padded_trajectories.append(padded_z) - assert np.mean(mask_lengths) > self.min_traj_size, f"Batches must contain trajectories with an average " \ - f"length above {self.min_traj_size}. Trajectories found: {traj_info}" + assert np.mean(mask_lengths) > self.min_traj_size, \ + "Batches must contain trajectories with an average " \ + "length above {self.min_traj_size}. Trajectories found: {traj_info}" stacked_trajectories = torch.stack(padded_trajectories, dim=0) return stacked_trajectories, mask_lengths @@ -380,7 +435,8 @@ def forward(self, x, traj_info): z = self.single_frame_encoder(x, traj_info).loc stacked_trajectories, mask_lengths = self._reshape_and_stack(z, traj_info) hiddens, final = self.context_rnn(stacked_trajectories) - # Pull out only the hidden states corresponding to actual non-padding inputs, and concat together + # Pull out only the hidden states corresponding to actual non-padding + # inputs, and concat together masked_hiddens = [] for i, trajectory_length in enumerate(mask_lengths): masked_hiddens.append(hiddens[i][:trajectory_length]) @@ -389,4 +445,3 @@ def forward(self, x, traj_info): mean = self.mean_layer(flattened_hiddens) scale = self.scale_layer(flattened_hiddens) return independent_multivariate_normal(loc=mean, scale=scale) - diff --git a/src/il_representations/algos/losses.py b/src/il_representations/algos/losses.py index 5708f335..40465050 100644 --- a/src/il_representations/algos/losses.py +++ b/src/il_representations/algos/losses.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod + +import stable_baselines3.common.logger as sb_logger import torch import torch.nn.functional as F -import stable_baselines3.common.logger as sb_logger class RepresentationLoss(ABC): @@ -14,9 +15,11 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist): pass def get_vector_forms(self, decoded_context_dist, target_dist, encoded_context_dist): - decoded_contexts = decoded_context_dist.sample() if self.sample else decoded_context_dist.loc + decoded_contexts = decoded_context_dist.sample( + ) if self.sample else decoded_context_dist.loc targets = target_dist.sample() if self.sample else target_dist.loc - encoded_contexts = encoded_context_dist.sample() if self.sample else encoded_context_dist.loc + encoded_contexts = encoded_context_dist.sample( + ) if self.sample else encoded_context_dist.loc return decoded_contexts, targets, encoded_contexts @@ -30,13 +33,17 @@ def __init__(self, device, sample=False, temp=0.1, normalize=True): self.criterion = torch.nn.CrossEntropyLoss() self.temp = temp - # Most methods use either cosine similarity or matrix multiplication similarity. Since cosine similarity equals - # taking MatMul on normalized vectors, setting normalize=True is equivalent to using torch.CosineSimilarity(). + # Most methods use either cosine similarity or matrix multiplication + # similarity. Since cosine similarity equals taking MatMul on + # normalized vectors, setting normalize=True is equivalent to using + # torch.CosineSimilarity(). self.normalize = normalize - # Sometimes the calculated vectors may contain an image's similarity with itself, which can be a large number. - # Since we mostly care about maximizing an image's similarity with its augmented version, we subtract a large - # number to make the classification have ~0 probability picking the original image itself. + # Sometimes the calculated vectors may contain an image's similarity + # with itself, which can be a large number. Since we mostly care about + # maximizing an image's similarity with its augmented version, we + # subtract a large number to make the classification have ~0 + # probability picking the original image itself. self.large_num = 1e9 def calculate_logits_and_labels(self, z_i, z_j, mask): @@ -46,7 +53,8 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None) # decoded_context -> representation of context + optional projection head # target -> representation of target + optional projection head # encoded_context -> not used by this loss - decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, encoded_context_dist) + decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, + encoded_context_dist) z_i = decoded_contexts z_j = targets @@ -65,15 +73,16 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None) class QueueAsymmetricContrastiveLoss(AsymmetricContrastiveLoss): """ - This implements algorithms that use a queue to maintain all the negative examples. The contrastive loss is - calculated through the comparison of an image with its augmented version (positive example) and everything else - in the queue (negative examples). The method used in MoCo. - - Alternatively, for higher sample efficiency, one may use (1) current batch's augmented images, (2) current batch's - original images, and (3) all the images in the queue as negative examples. This is implemented with setting - use_batch_neg=True. + This implements algorithms that use a queue to maintain all the negative + examples. The contrastive loss is calculated through the comparison of an + image with its augmented version (positive example) and everything else in + the queue (negative examples). The method used in MoCo. + + Alternatively, for higher sample efficiency, one may use (1) current + batch's augmented images, (2) current batch's original images, and (3) all + the images in the queue as negative examples. This is implemented with + setting use_batch_neg=True. """ - def __init__(self, device, sample=False, temp=0.1, use_batch_neg=False): super(QueueAsymmetricContrastiveLoss, self).__init__(device, sample) @@ -89,7 +98,8 @@ def calculate_logits_and_labels(self, z_i, z_j, mask): queue = z_j[batch_size:] z_j = z_j[:batch_size] - # Calculate the dot product similarity of each image with all images in the queue. Return an NxK tensor. + # Calculate the dot product similarity of each image with all images in + # the queue. Return an NxK tensor. l_neg = torch.matmul(z_i, queue.T) # NxK if self.use_batch_neg: @@ -104,13 +114,16 @@ def calculate_logits_and_labels(self, z_i, z_j, mask): logits = torch.cat([logits_ab, logits_aa, l_neg], dim=1) # Nx(2N+K) - # The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of - # represent(image_i) and represent(augmented_image_i). + # The values we want to maximize lie on the i-th index of each row + # i. i.e. the dot product of represent(image_i) and + # represent(augmented_image_i). labels = torch.arange(batch_size, dtype=torch.long).to(self.device) else: - # torch.einsum provides an elegant way to calculate vector dot products across a batch. Each entry on the - # Nx1 returned tensor is a dot product of represent(image_i) and represent(augmented_image_i). + # torch.einsum provides an elegant way to calculate vector dot + # products across a batch. Each entry on the Nx1 returned tensor is + # a dot product of represent(image_i) and + # represent(augmented_image_i). l_pos = torch.einsum('nc,nc->n', [z_i, z_j]).unsqueeze(-1) # Nx1 # The negative examples here only contain image representations in the queue. @@ -124,9 +137,10 @@ def calculate_logits_and_labels(self, z_i, z_j, mask): class BatchAsymmetricContrastiveLoss(AsymmetricContrastiveLoss): """ - This applies to algorithms that performs asymmetric contrast with samples in the same batch. i.e. Negative examples - come from all other images (and their augmented versions) in the same batch. Represents InfoNCE used in original - CPC paper. + This applies to algorithms that performs asymmetric contrast with samples + in the same batch. i.e. Negative examples come from all other images (and + their augmented versions) in the same batch. Represents InfoNCE used in + original CPC paper. """ def __init__(self, device, sample=False, temp=0.1): super(BatchAsymmetricContrastiveLoss, self).__init__(device, sample) @@ -139,7 +153,8 @@ def calculate_logits_and_labels(self, z_i, z_j, mask): """ batch_size = z_i.shape[0] - # Similarity of the original images with all other original images in current batch. Return a matrix of NxN. + # Similarity of the original images with all other original images in + # current batch. Return a matrix of NxN. logits_aa = torch.matmul(z_i, z_i.T) # NxN # Values on the diagonal line are each image's similarity with itself @@ -150,38 +165,44 @@ def calculate_logits_and_labels(self, z_i, z_j, mask): logits = torch.cat((logits_ab, logits_aa), 1) # Nx2N - # The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of - # represent(image_i) and represent(augmented_image_i). + # The values we want to maximize lie on the i-th index of each row i. + # i.e. the dot product of represent(image_i) and + # represent(augmented_image_i). label = torch.arange(batch_size, dtype=torch.long).to(self.device) return logits, label class SymmetricContrastiveLoss(RepresentationLoss): """ - A contrastive loss that does prediction "in both directions," i.e. that calculates logits of IJ similarity against - all similarities with J, and also all similarities with I, and calculates cross-entropy on both + A contrastive loss that does prediction "in both directions," i.e. that + calculates logits of IJ similarity against all similarities with J, and + also all similarities with I, and calculates cross-entropy on both """ - def __init__(self, device, sample=False, temp=0.1, normalize=True): super(SymmetricContrastiveLoss, self).__init__(device, sample) self.criterion = torch.nn.CrossEntropyLoss() self.temp = temp - # Most methods use either cosine similarity or matrix multiplication similarity. Since cosine similarity equals - # taking MatMul on normalized vectors, setting normalize=True is equivalent to using torch.CosineSimilarity(). + # Most methods use either cosine similarity or matrix multiplication + # similarity. Since cosine similarity equals taking MatMul on + # normalized vectors, setting normalize=True is equivalent to using + # torch.CosineSimilarity(). self.normalize = normalize - # Sometimes the calculated vectors may contain an image's similarity with itself, which can be a large number. - # Since we mostly care about maximizing an image's similarity with its augmented version, we subtract a large - # number to make the classification have ~0 probability picking the original image itself. + # Sometimes the calculated vectors may contain an image's similarity + # with itself, which can be a large number. Since we mostly care about + # maximizing an image's similarity with its augmented version, we + # subtract a large number to make the classification have ~0 + # probability picking the original image itself. self.large_num = 1e9 def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None): # decoded_context -> representation of context + optional projection head # target -> representation of target + optional projection head # encoded_context -> not used by this loss - decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, encoded_context_dist) + decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, + encoded_context_dist) z_i = decoded_contexts z_j = targets batch_size = z_i.shape[0] @@ -192,7 +213,8 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None) mask = (torch.eye(batch_size) * self.large_num).to(self.device) - # Similarity of the original images with all other original images in current batch. Return a matrix of NxN. + # Similarity of the original images with all other original images in + # current batch. Return a matrix of NxN. logits_aa = torch.matmul(z_i, z_i.T) # NxN # Values on the diagonal line are each image's similarity with itself @@ -212,15 +234,17 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None) sb_logger.record('avg_other_similarity', avg_other_similarity) sb_logger.record('self_other_sim_delta', avg_self_similarity - avg_other_similarity) - # Each row now contains an image's similarity with the batch's augmented images & original images. This applies - # to both original and augmented images (hence "symmetric"). + # Each row now contains an image's similarity with the batch's + # augmented images & original images. This applies to both original and + # augmented images (hence "symmetric"). logits_i = torch.cat((logits_ab, logits_aa), 1) # Nx2N logits_j = torch.cat((logits_ba, logits_bb), 1) # Nx2N logits = torch.cat((logits_i, logits_j), axis=0) # 2Nx2N logits /= self.temp - # The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of - # represent(image_i) and represent(augmented_image_i). + # The values we want to maximize lie on the i-th index of each row i. + # i.e. the dot product of represent(image_i) and + # represent(augmented_image_i). label = torch.arange(batch_size, dtype=torch.long).to(self.device) labels = torch.cat((label, label), axis=0) @@ -233,16 +257,16 @@ def __init__(self, device, sample=False): self.criterion = torch.nn.MSELoss() def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None): - decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, encoded_context_dist) + decoded_contexts, targets, _ = self.get_vector_forms(decoded_context_dist, target_dist, + encoded_context_dist) return self.criterion(decoded_contexts, targets) class CEBLoss(RepresentationLoss): """ - A variational contrastive loss that implements information bottlenecking, but in a less conservative form - than done by traditional VIB techniques + A variational contrastive loss that implements information bottlenecking, + but in a less conservative form than done by traditional VIB techniques """ - def __init__(self, device, beta=.1, sample=True): super().__init__(device, sample=sample) # TODO allow for beta functions @@ -252,14 +276,23 @@ def __init__(self, device, beta=.1, sample=True): def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None): z = decoded_context_dist.rsample() - log_ezx = decoded_context_dist.log_prob(z) # B -> Log proba of each vector in z under the distribution it was sampled from - log_bzy = target_dist.log_prob(z) # B -> Log proba of each vector in z under the distribution conditioned on its corresponding target - cross_probas_logits = torch.stack([target_dist.log_prob(z[i]) for i in range(z.shape[0])], dim=0) # BxB Log proba of each vector z[i] under _all_ target distributions - # The return shape of target_dist.log_prob(z[i]) is the probability of z[i] under each distribution in the batch - catgen = torch.distributions.Categorical(logits=cross_probas_logits) # logits of shape BxB -> Batch categorical, one distribution per element in z over possible - # targets/y values + # B -> Log proba of each vector in z under the distribution it was + # sampled from + log_ezx = decoded_context_dist.log_prob(z) + # B -> Log proba of each vector in z under the distribution conditioned + # on its corresponding target + log_bzy = target_dist.log_prob(z) + cross_probas_logits = torch.stack( + [target_dist.log_prob(z[i]) for i in range(z.shape[0])], + dim=0) # BxB Log proba of each vector z[i] under _all_ target distributions + # The return shape of target_dist.log_prob(z[i]) is the probability of + # z[i] under each distribution in the batch. + # logits of shape BxB -> Batch categorical, one distribution per element in z over possible + catgen = torch.distributions.Categorical(logits=cross_probas_logits) + # targets/y values inds = (torch.arange(start=0, end=len(z))).to(self.device) - i_yz = catgen.log_prob(inds) # The probability of the kth target under the kth Categorical distribution (probability of true y) - loss = torch.mean(self.beta*(log_ezx - log_bzy) - i_yz) + # The probability of the kth target under the kth Categorical + # distribution (probability of true y) + i_yz = catgen.log_prob(inds) + loss = torch.mean(self.beta * (log_ezx - log_bzy) - i_yz) return loss - diff --git a/src/il_representations/algos/optimizers.py b/src/il_representations/algos/optimizers.py index 590e7f97..31999b4a 100644 --- a/src/il_representations/algos/optimizers.py +++ b/src/il_representations/algos/optimizers.py @@ -26,22 +26,28 @@ class LARS(Optimizer): >>> loss_fn(model(input), target).backward() >>> optimizer.step() """ - def __init__(self, params, lr=required, momentum=.9, - weight_decay=.0005, eta=0.001, max_epoch=200): + def __init__(self, + params, + lr=required, + momentum=.9, + weight_decay=.0005, + eta=0.001, + max_epoch=200): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}" - .format(weight_decay)) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if eta < 0.0: raise ValueError("Invalid LARS coefficient value: {}".format(eta)) self.epoch = 0 - defaults = dict(lr=lr, momentum=momentum, + defaults = dict(lr=lr, + momentum=momentum, weight_decay=weight_decay, - eta=eta, max_epoch=max_epoch) + eta=eta, + max_epoch=max_epoch) super(LARS, self).__init__(params, defaults) def step(self, epoch=None, closure=None): @@ -79,7 +85,7 @@ def step(self, epoch=None, closure=None): grad_norm = torch.norm(d_p) # Global LR computed on polynomial decay schedule - decay = (1 - float(epoch) / max_epoch) ** 2 + decay = (1 - float(epoch) / max_epoch)**2 global_lr = lr * decay # Compute local learning rate for this layer @@ -91,7 +97,7 @@ def step(self, epoch=None, closure=None): if 'momentum_buffer' not in param_state: buf = param_state['momentum_buffer'] = \ - torch.zeros_like(p.data) + torch.zeros_like(p.data) else: buf = param_state['momentum_buffer'] buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data) diff --git a/src/il_representations/algos/pair_constructors.py b/src/il_representations/algos/pair_constructors.py index 43c6913c..d519c8f9 100644 --- a/src/il_representations/algos/pair_constructors.py +++ b/src/il_representations/algos/pair_constructors.py @@ -1,26 +1,28 @@ -import numpy as np -from abc import ABC, abstractmethod - """ -Pair Constructors turn a basic trajectory dataset into a dataset of `context`, `target`, and `extra_context` data -elements, along with a metadata tensor containing the trajectory ID and timestep ID for each element in the dataset. -The `context` element is conceptually thought of as the element you're using to do prediction, whereas the `target` -is the ground truth or "positive" we're trying to predict from the context, though this prediction framework is +Pair Constructors turn a basic trajectory dataset into a dataset of `context`, +`target`, and `extra_context` data elements, along with a metadata tensor +containing the trajectory ID and timestep ID for each element in the dataset. +The `context` element is conceptually thought of as the element you're using to +do prediction, whereas the `target` is the ground truth or "positive" we're +trying to predict from the context, though this prediction framework is admittedly a somewhat fuzzy match to the actual variety of techniques. -- In temporal contrastive loss settings, context is generally the element at position (t), and target the element at -position (t+k) -- In pure-augmentation contrastive loss settings, context and target are the same element (which will be augmented -in different ways) -- In a VAE, context and target are also the same element. Context will be mapped into a representation and then decoded -back out, whereas the target will "tag along" as ground truth pixels needed to calculate the loss. -- In Dynamics modeling, context is the current state at time (t), target is the state at time (t+1) and extra context -is the action taken at time (t) +- In temporal contrastive loss settings, context is generally the element at + position (t), and target the element at position (t+k) +- In pure-augmentation contrastive loss settings, context and target are the + same element (which will be augmented in different ways) +- In a VAE, context and target are also the same element. Context will be + mapped into a representation and then decoded back out, whereas the target + will "tag along" as ground truth pixels needed to calculate the loss. +- In Dynamics modeling, context is the current state at time (t), target is the + state at time (t+1) and extra context is the action taken at time (t) """ +from abc import ABC, abstractmethod +import numpy as np -class TargetPairConstructor(ABC): +class TargetPairConstructor(ABC): @abstractmethod def __call__(self, data_dict): pass @@ -29,10 +31,16 @@ def __call__(self, data_dict): class IdentityPairConstructor(TargetPairConstructor): def __call__(self, data_dict): obs, actions, dones = data_dict['obs'], data_dict['acts'], data_dict['dones'] + del actions # unused dataset = [] trajectory_ind = timestep = 0 for i in range(len(dones)): - dataset.append({'context': obs[i], 'target': obs[i], 'extra_context': [], 'traj_ts_ids': [trajectory_ind, timestep]}) + dataset.append({ + 'context': obs[i], + 'target': obs[i], + 'extra_context': [], + 'traj_ts_ids': [trajectory_ind, timestep] + }) timestep += 1 if dones[i]: trajectory_ind += 1 @@ -52,7 +60,7 @@ def __call__(self, data_dict): trajectory_ind = timestep = 0 i = 0 while i < len(dones) - self.k: - if np.any(dones[i:i+self.k]): + if np.any(dones[i:i + self.k]): # If dones[i] is true, next obs is from new trajectory, skip # Also skip if we are i += self.k @@ -60,16 +68,27 @@ def __call__(self, data_dict): timestep = 0 continue if self.mode is None: - dataset.append({'context': obs[i], 'target': obs[i + self.k], - 'extra_context': [], 'traj_ts_ids': [trajectory_ind, timestep]}) + dataset.append({ + 'context': obs[i], + 'target': obs[i + self.k], + 'extra_context': [], + 'traj_ts_ids': [trajectory_ind, timestep] + }) elif self.mode == 'dynamics': - dataset.append({'context': obs[i], 'target': obs[i + self.k], - 'extra_context': actions[i:i+self.k], 'traj_ts_ids': [trajectory_ind, timestep]}) + dataset.append({ + 'context': obs[i], + 'target': obs[i + self.k], + 'extra_context': actions[i:i + self.k], + 'traj_ts_ids': [trajectory_ind, timestep] + }) elif self.mode == 'inverse_dynamics': - dataset.append({'context': obs[i], 'target': actions[i:i+self.k], - 'extra_context': obs[i + self.k], 'traj_ts_ids': [trajectory_ind, timestep]}) + dataset.append({ + 'context': obs[i], + 'target': actions[i:i + self.k], + 'extra_context': obs[i + self.k], + 'traj_ts_ids': [trajectory_ind, timestep] + }) timestep += 1 i += 1 return dataset - diff --git a/src/il_representations/algos/representation_learner.py b/src/il_representations/algos/representation_learner.py index b6717e32..83d63b2b 100644 --- a/src/il_representations/algos/representation_learner.py +++ b/src/il_representations/algos/representation_learner.py @@ -1,33 +1,31 @@ +import inspect import os -import torch + +import imitation.util.logger as logger from stable_baselines3.common.preprocessing import preprocess_obs from stable_baselines3.common.utils import get_device +import torch from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter -from imitation.augment import StandardAugmentations -from il_representations.algos.batch_extenders import IdentityBatchExtender, QueueBatchExtender -from il_representations.algos.base_learner import BaseEnvironmentLearner -from il_representations.algos.utils import AverageMeter, Logger -from il_representations.algos.augmenters import AugmentContextOnly -from gym.spaces import Box -import torch -import inspect -import imitation.util.logger as logger +from il_representations.algos.augmenters import AugmentContextOnly +from il_representations.algos.base_learner import BaseEnvironmentLearner +from il_representations.algos.batch_extenders import IdentityBatchExtender, QueueBatchExtender +from il_representations.algos.utils import AverageMeter -DEFAULT_HARDCODED_PARAMS = ['encoder', 'decoder', 'loss_calculator', 'augmenter', 'target_pair_constructor'] +DEFAULT_HARDCODED_PARAMS = [ + 'encoder', 'decoder', 'loss_calculator', 'augmenter', 'target_pair_constructor' +] def get_default_args(func): signature = inspect.signature(func) return { k: v.default - for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty + for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty } - def to_dict(kwargs_element): # To get around not being able to have empty dicts as default values if kwargs_element is None: @@ -37,8 +35,13 @@ def to_dict(kwargs_element): class RepresentationLearner(BaseEnvironmentLearner): - def __init__(self, env, *, - log_dir, encoder, decoder, loss_calculator, + def __init__(self, + env, + *, + log_dir, + encoder, + decoder, + loss_calculator, target_pair_constructor, augmenter=AugmentContextOnly, batch_extender=IdentityBatchExtender, @@ -66,31 +69,35 @@ def __init__(self, env, *, self.log_dir = log_dir logger.configure(log_dir, ["stdout", "tensorboard"]) - self.encoder_checkpoints_path = os.path.join(self.log_dir, 'checkpoints', 'representation_encoder') + self.encoder_checkpoints_path = os.path.join(self.log_dir, 'checkpoints', + 'representation_encoder') os.makedirs(self.encoder_checkpoints_path, exist_ok=True) self.decoder_checkpoints_path = os.path.join(self.log_dir, 'checkpoints', 'loss_decoder') os.makedirs(self.decoder_checkpoints_path, exist_ok=True) - self.device = get_device("auto" if device is None else device) self.shuffle_batches = shuffle_batches self.batch_size = batch_size self.preprocess_extra_context = preprocess_extra_context self.save_interval = save_interval - #self._make_channels_first() + # self._make_channels_first() self.unit_test_max_train_steps = unit_test_max_train_steps if projection_dim is None: - # If no projection_dim is specified, it will be assumed to be the same as representation_dim - # This doesn't have any meaningful effect unless you specify a projection head. + # If no projection_dim is specified, it will be assumed to be the + # same as representation_dim This doesn't have any meaningful + # effect unless you specify a projection head. projection_dim = representation_dim self.augmenter = augmenter(**augmenter_kwargs) - self.target_pair_constructor = target_pair_constructor(**to_dict(target_pair_constructor_kwargs)) + self.target_pair_constructor = target_pair_constructor( + **to_dict(target_pair_constructor_kwargs)) encoder_kwargs = to_dict(encoder_kwargs) - self.encoder = encoder(self.observation_space, representation_dim, **encoder_kwargs).to(self.device) - self.decoder = decoder(representation_dim, projection_dim, **to_dict(decoder_kwargs)).to(self.device) + self.encoder = encoder(self.observation_space, representation_dim, + **encoder_kwargs).to(self.device) + self.decoder = decoder(representation_dim, projection_dim, + **to_dict(decoder_kwargs)).to(self.device) if batch_extender is QueueBatchExtender: # TODO maybe clean this up? @@ -115,10 +122,14 @@ def __init__(self, env, *, self.scheduler = scheduler(self.optimizer, **to_dict(scheduler_kwargs)) else: self.scheduler = None - self.writer = SummaryWriter(log_dir=os.path.join(log_dir, 'contrastive_tf_logs'), flush_secs=15) - - def validate_and_update_kwargs(self, user_kwargs, kwargs_updates=None, - hardcoded_params=None, params_cleaned=False): + self.writer = SummaryWriter(log_dir=os.path.join(log_dir, 'contrastive_tf_logs'), + flush_secs=15) + + def validate_and_update_kwargs(self, + user_kwargs, + kwargs_updates=None, + hardcoded_params=None, + params_cleaned=False): # return a copy instead of updating in-place to avoid inconsistent state # after a failed update user_kwargs_copy = user_kwargs.copy() @@ -131,8 +142,9 @@ def validate_and_update_kwargs(self, user_kwargs, kwargs_updates=None, if hardcoded_param not in user_kwargs_copy: continue if user_kwargs_copy[hardcoded_param] != default_args[hardcoded_param]: - raise ValueError(f"You passed in a non-default value for parameter {hardcoded_param} " - f"hardcoded by {self.__class__.__name__}") + raise ValueError( + f"You passed in a non-default value for parameter {hardcoded_param} " + f"hardcoded by {self.__class__.__name__}") del user_kwargs_copy[hardcoded_param] if kwargs_updates is not None: @@ -140,9 +152,10 @@ def validate_and_update_kwargs(self, user_kwargs, kwargs_updates=None, raise TypeError("kwargs_updates must be passed in in the form of a dict ") for kwarg_update_key in kwargs_updates.keys(): if isinstance(user_kwargs_copy[kwarg_update_key], dict): - user_kwargs_copy[kwarg_update_key] = self.validate_and_update_kwargs(user_kwargs_copy[kwarg_update_key], - kwargs_updates[kwarg_update_key], - params_cleaned=True) + user_kwargs_copy[kwarg_update_key] = self.validate_and_update_kwargs( + user_kwargs_copy[kwarg_update_key], + kwargs_updates[kwarg_update_key], + params_cleaned=True) else: user_kwargs_copy[kwarg_update_key] = kwargs_updates[kwarg_update_key] return user_kwargs_copy @@ -168,10 +181,9 @@ def _prep_tensors(self, tensors_or_arrays): batch_tensor.shape[1] < batch_tensor.shape[2] \ and batch_tensor.shape[1] < batch_tensor.shape[3] if not is_nchw_heuristic: - raise ValueError( - f"Batch tensor axes {batch_tensor.shape} do not look " - "like they're in NCHW order. Did you accidentally pass in " - "a channels-last tensor?") + raise ValueError(f"Batch tensor axes {batch_tensor.shape} do not look " + "like they're in NCHW order. Did you accidentally pass in " + "a channels-last tensor?") if torch.is_floating_point(batch_tensor): # cast double to float for perf reasons (also drops half-precision) dtype = torch.float @@ -183,8 +195,7 @@ def _prep_tensors(self, tensors_or_arrays): def _preprocess(self, input_data): # SB will normalize to [0,1] - return preprocess_obs(input_data, self.observation_space, - normalize_images=True) + return preprocess_obs(input_data, self.observation_space, normalize_images=True) def _preprocess_extra_context(self, extra_context): if extra_context is None or not self.preprocess_extra_context: @@ -194,9 +205,11 @@ def _preprocess_extra_context(self, extra_context): # TODO maybe make static? def unpack_batch(self, batch): """ - :param batch: A batch that may contain a numpy array of extra context, but may also simply have an - empty list as a placeholder value for the `extra_context` key. If the latter, return None for extra_context, - rather than an empty list (Torch data loaders can only work with lists and arrays, not None types) + :param batch: A batch that may contain a numpy array of extra context, + but may also simply have an empty list as a placeholder value for + the `extra_context` key. If the latter, return None for + extra_context, rather than an empty list (Torch data loaders can + only work with lists and arrays, not None types) :return: """ if len(batch['extra_context']) == 0: @@ -231,29 +244,37 @@ def learn(self, dataset, training_epochs): contexts, targets = self._prep_tensors(contexts), self._prep_tensors(targets) extra_context = self._prep_tensors(extra_context) traj_ts_info = self._prep_tensors(traj_ts_info) - # Note: preprocessing might be better to do on CPU if, in future, we can parallelize doing so + # Note: preprocessing might be better to do on CPU if, in + # future, we can parallelize doing so contexts, targets = self._preprocess(contexts), self._preprocess(targets) contexts, targets = self.augmenter(contexts, targets) extra_context = self._preprocess_extra_context(extra_context) - # These will typically just use the forward() function for the encoder, but can optionally - # use a specific encode_context and encode_target if one is implemented + # These will typically just use the forward() function for the + # encoder, but can optionally use a specific encode_context and + # encode_target if one is implemented encoded_contexts = self.encoder.encode_context(contexts, traj_ts_info) encoded_targets = self.encoder.encode_target(targets, traj_ts_info) # Typically the identity function extra_context = self.encoder.encode_extra_context(extra_context, traj_ts_info) - # Use an algorithm-specific decoder to "decode" the representations into a loss-compatible tensor - # As with encode, these will typically just use forward() - decoded_contexts = self.decoder.decode_context(encoded_contexts, traj_ts_info, extra_context) - decoded_targets = self.decoder.decode_target(encoded_targets, traj_ts_info, extra_context) + # Use an algorithm-specific decoder to "decode" the + # representations into a loss-compatible tensor As with encode, + # these will typically just use forward() + decoded_contexts = self.decoder.decode_context(encoded_contexts, traj_ts_info, + extra_context) + decoded_targets = self.decoder.decode_target(encoded_targets, traj_ts_info, + extra_context) - # Optionally add to the batch before loss. By default, this is an identity operation, but - # can also implement momentum queue logic - decoded_contexts, decoded_targets = self.batch_extender(decoded_contexts, decoded_targets) + # Optionally add to the batch before loss. By default, this is + # an identity operation, but can also implement momentum queue + # logic + decoded_contexts, decoded_targets = self.batch_extender( + decoded_contexts, decoded_targets) - # Use an algorithm-specific loss function. Typically this only requires decoded_contexts and - # decoded_targets, but VAE requires encoded_contexts, so we pass it in here + # Use an algorithm-specific loss function. Typically this only + # requires decoded_contexts and decoded_targets, but VAE + # requires encoded_contexts, so we pass it in here loss = self.loss_calculator(decoded_contexts, decoded_targets, encoded_contexts) @@ -278,5 +299,7 @@ def learn(self, dataset, training_epochs): self.encoder.train(False) self.decoder.train(False) if epoch % self.save_interval == 0: - torch.save(self.encoder, os.path.join(self.encoder_checkpoints_path, f'{epoch}_epochs.ckpt')) - torch.save(self.decoder, os.path.join(self.decoder_checkpoints_path, f'{epoch}_epochs.ckpt')) + torch.save(self.encoder, + os.path.join(self.encoder_checkpoints_path, f'{epoch}_epochs.ckpt')) + torch.save(self.decoder, + os.path.join(self.decoder_checkpoints_path, f'{epoch}_epochs.ckpt')) diff --git a/src/il_representations/algos/utils.py b/src/il_representations/algos/utils.py index 131b382b..870ab92c 100644 --- a/src/il_representations/algos/utils.py +++ b/src/il_representations/algos/utils.py @@ -1,31 +1,30 @@ -import torch -import numpy as np -import random -import gym +from datetime import datetime import math -from torch.optim.lr_scheduler import _LRScheduler +from numbers import Number import os -import torch -import numpy as np -import matplotlib.pyplot as plt -from datetime import datetime +import random + from PIL import Image -from numbers import Number import cv2 +import gym +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim.lr_scheduler import _LRScheduler def independent_multivariate_normal(loc, scale): batch_dim = loc.shape[0] if isinstance(scale, Number): # If the scale was passed in as a scalar, convert it to the same shape as loc - scale = torch.ones(loc.shape, device=loc.device)*scale + scale = torch.ones(loc.shape, device=loc.device) * scale # Turn each -length vector in the batch into a diagonal matrix, because we want an # independent multivariate normal merged_covariance_matrix = torch.stack([torch.diag(scale[i]) for i in range(batch_dim)]) - return torch.distributions.MultivariateNormal( - loc=loc, - covariance_matrix=merged_covariance_matrix.to(loc.device)) + return torch.distributions.MultivariateNormal(loc=loc, + covariance_matrix=merged_covariance_matrix.to( + loc.device)) def add_noise(state, noise_std_dev): @@ -51,7 +50,8 @@ def show_plt_image(img): plt.show() -# TODO: Have the calls to savefig below save to the log directory (or at least make the output directory in case it doesn't exist) +# TODO: Have the calls to savefig below save to the log directory (or at least +# make the output directory in case it doesn't exist) def plot(arr, env_id, gap=1): fig = plt.figure() x = np.arange(len(arr.shape[1])) * gap @@ -80,7 +80,8 @@ def plot_single(arr, label, msg): def save_model(model, env_id, save_path): os.makedirs(save_path, exist_ok=True) - torch.save(model.state_dict(), os.path.join(save_path, f'[{time_now(datetime.now())}]{env_id}.pth')) + torch.save(model.state_dict(), + os.path.join(save_path, f'[{time_now(datetime.now())}]{env_id}.pth')) def time_now(n): @@ -90,7 +91,7 @@ def time_now(n): class Logger: def __init__(self, log_dir): - self.file = os.path.join(log_dir, f'train_log.txt') + self.file = os.path.join(log_dir, 'train_log.txt') def log(self, msg): t = datetime.now() @@ -112,14 +113,17 @@ def get_lr(self): if self.warmup_epoch > 0: if self.last_epoch <= self.warmup_epoch: return [base_lr / self.warmup_epoch * self.last_epoch for base_lr in self.base_lrs] - if ((self.last_epoch - self.warmup_epoch) - 1 - (self.T_max - self.warmup_epoch)) % (2 * (self.T_max - self.warmup_epoch)) == 0: - return [group['lr'] + (base_lr - self.eta_min) * - (1 - math.cos(math.pi / (self.T_max - self.warmup_epoch))) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] + cur_prog = self.last_epoch - self.warmup_epoch + max_prog = self.T_max - self.warmup_epoch + if (cur_prog - 1 - max_prog) % (2 * max_prog) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / (self.T_max - self.warmup_epoch))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] else: - return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epoch) / (self.T_max - self.warmup_epoch))) / - (1 + math.cos(math.pi * ((self.last_epoch - self.warmup_epoch) - 1) / (self.T_max - self.warmup_epoch))) * + return [(1 + math.cos(math.pi * cur_prog / max_prog)) / + (1 + math.cos(math.pi * (cur_prog - 1) / max_prog)) * (group['lr'] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups] @@ -138,7 +142,7 @@ def set_global_seeds(seed): gym.spaces.prng.seed(seed) -def accuracy(output, target, topk=(1,)): +def accuracy(output, target, topk=(1, )): """Computes the precision@k for the specified values of k""" maxk = max(topk) batch_size = target.size(0) diff --git a/src/il_representations/data.py b/src/il_representations/data.py index ce7f7c80..10d2d999 100644 --- a/src/il_representations/data.py +++ b/src/il_representations/data.py @@ -13,16 +13,13 @@ class TransitionsMinimalDataset(il_datasets.Dataset): def __init__(self, data_map): req_keys = {'obs', 'acts', 'next_obs', 'dones'} assert req_keys <= data_map.keys() - self.dict_dataset = il_datasets.RandomDictDataset( - {k: data_map[k] - for k in req_keys}) + self.dict_dataset = il_datasets.RandomDictDataset({k: data_map[k] for k in req_keys}) def sample(self, n_samples): dict_samples = self.dict_dataset.sample(n_samples) # we don't have infos dicts, so we insert some fake ones to make # TransitionsMinimal happy - dummy_infos = np.asarray([{} for _ in range(n_samples)], - dtype='object') + dummy_infos = np.asarray([{} for _ in range(n_samples)], dtype='object') result = il_types.Transitions(infos=dummy_infos, **dict_samples) assert len(result) == n_samples return result diff --git a/src/il_representations/envs/atari_envs.py b/src/il_representations/envs/atari_envs.py index bd0038cb..24d6faee 100644 --- a/src/il_representations/envs/atari_envs.py +++ b/src/il_representations/envs/atari_envs.py @@ -1,13 +1,13 @@ """Utilities for working with Atari environments and demonstrations.""" -import numpy as np import random +import numpy as np + from il_representations.envs.config import benchmark_ingredient @benchmark_ingredient.capture -def load_dataset_atari(atari_env_id, atari_demo_paths, n_traj, - chans_first=True): +def load_dataset_atari(atari_env_id, atari_demo_paths, n_traj, chans_first=True): # load trajectories from disk full_rollouts_path = atari_demo_paths[atari_env_id] trajs_or_file = np.load(full_rollouts_path, allow_pickle=True) @@ -33,10 +33,7 @@ def load_dataset_atari(atari_env_id, atari_demo_paths, n_traj, merged_trajectories['next_obs'] += traj['states'][1:] merged_trajectories['acts'] += traj['actions'][:-1] merged_trajectories['dones'] += traj['dones'][:-1] - dataset_dict = { - key: np.stack(values, axis=0) - for key, values in merged_trajectories.items() - } + dataset_dict = {key: np.stack(values, axis=0) for key, values in merged_trajectories.items()} if chans_first: # In Gym Atari envs, channels are last; chans_first will transpose data diff --git a/src/il_representations/envs/auto.py b/src/il_representations/envs/auto.py index 1b1033f9..b0f2c022 100644 --- a/src/il_representations/envs/auto.py +++ b/src/il_representations/envs/auto.py @@ -3,15 +3,13 @@ from imitation.util.util import make_vec_env from stable_baselines3.common.atari_wrappers import AtariWrapper -from stable_baselines3.common.vec_env import (DummyVecEnv, SubprocVecEnv, - VecFrameStack, VecTransposeImage) +from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage from il_representations.algos.augmenters import ColorSpace from il_representations.envs.atari_envs import load_dataset_atari from il_representations.envs.config import benchmark_ingredient from il_representations.envs.dm_control_envs import load_dataset_dm_control -from il_representations.envs.magical_envs import (get_env_name_magical, - load_dataset_magical) +from il_representations.envs.magical_envs import get_env_name_magical, load_dataset_magical ERROR_MESSAGE = "no support for benchmark_name={benchmark['benchmark_name']!r}" @@ -30,8 +28,7 @@ def load_dataset(benchmark_name): @benchmark_ingredient.capture -def get_gym_env_name(benchmark_name, atari_env_id, dm_control_full_env_names, - dm_control_env): +def get_gym_env_name(benchmark_name, atari_env_id, dm_control_full_env_names, dm_control_env): if benchmark_name == 'magical': return get_env_name_magical() elif benchmark_name == 'dm_control': @@ -42,15 +39,13 @@ def get_gym_env_name(benchmark_name, atari_env_id, dm_control_full_env_names, @benchmark_ingredient.capture -def load_vec_env(benchmark_name, atari_env_id, dm_control_full_env_names, - dm_control_env, venv_parallel, n_envs): +def load_vec_env(benchmark_name, atari_env_id, dm_control_full_env_names, dm_control_env, + venv_parallel, n_envs): """Create a vec env for the selected benchmark task and wrap it with any necessary wrappers.""" gym_env_name = get_gym_env_name() if benchmark_name in ('magical', 'dm_control'): - return make_vec_env(gym_env_name, - n_envs=n_envs, - parallel=venv_parallel) + return make_vec_env(gym_env_name, n_envs=n_envs, parallel=venv_parallel) elif benchmark_name == 'atari': assert not venv_parallel, "currently does not support parallel kwarg" raw_atari_env = make_vec_env(gym_env_name, diff --git a/src/il_representations/envs/config.py b/src/il_representations/envs/config.py index 2ddfa0f2..f97fa137 100644 --- a/src/il_representations/envs/config.py +++ b/src/il_representations/envs/config.py @@ -75,10 +75,8 @@ def bench_defaults(): atari_env_id = 'PongNoFrameskip-v4' atari_demo_paths = { - 'BreakoutNoFrameskip-v4': - "data/atari/BreakoutNoFrameskip-v4_rollouts_500_ts_100_traj.npz", - 'PongNoFrameskip-v4': - "data/atari/PongNoFrameskip-v4_rollouts_500_ts_100_traj.npz", + 'BreakoutNoFrameskip-v4': "data/atari/BreakoutNoFrameskip-v4_rollouts_500_ts_100_traj.npz", + 'PongNoFrameskip-v4': "data/atari/PongNoFrameskip-v4_rollouts_500_ts_100_traj.npz", } _ = locals() diff --git a/src/il_representations/envs/dm_control_envs.py b/src/il_representations/envs/dm_control_envs.py index 1441632a..45aea8e1 100644 --- a/src/il_representations/envs/dm_control_envs.py +++ b/src/il_representations/envs/dm_control_envs.py @@ -23,13 +23,12 @@ def register_dmc_envs(): return _REGISTERED = True - common = dict( - seed=0, - visualize_reward=False, - from_pixels=True, - height=IMAGE_SIZE, - width=IMAGE_SIZE, - channels_first=True) + common = dict(seed=0, + visualize_reward=False, + from_pixels=True, + height=IMAGE_SIZE, + width=IMAGE_SIZE, + channels_first=True) def entry_point(**kwargs): # add in common kwargs @@ -38,43 +37,31 @@ def entry_point(**kwargs): # frame skip 2 gym.register('DMC-Finger-Spin-v0', entry_point=entry_point, - kwargs=dict(domain_name='finger', - task_name='spin', - frame_skip=2)) + kwargs=dict(domain_name='finger', task_name='spin', frame_skip=2)) # frame skip 4 gym.register('DMC-Cheetah-Run-v0', entry_point=entry_point, - kwargs=dict(domain_name='cheetah', - task_name='run', - frame_skip=4)) + kwargs=dict(domain_name='cheetah', task_name='run', frame_skip=4)) # frame skip 8 gym.register('DMC-Walker-Walk-v0', entry_point=entry_point, - kwargs=dict(domain_name='walker', - task_name='walk', - frame_skip=8)) + kwargs=dict(domain_name='walker', task_name='walk', frame_skip=8)) gym.register('DMC-Cartpole-Swingup-v0', entry_point=entry_point, - kwargs=dict(domain_name='cartpole', - task_name='swingup', - frame_skip=8)) + kwargs=dict(domain_name='cartpole', task_name='swingup', frame_skip=8)) gym.register('DMC-Reacher-Easy-v0', entry_point=entry_point, - kwargs=dict(domain_name='reacher', - task_name='easy', - frame_skip=8)) + kwargs=dict(domain_name='reacher', task_name='easy', frame_skip=8)) gym.register('DMC-Ball-In-Cup-Catch-v0', entry_point=entry_point, - kwargs=dict(domain_name='ball_in_cup', - task_name='catch', - frame_skip=8)) + kwargs=dict(domain_name='ball_in_cup', task_name='catch', frame_skip=8)) @benchmark_ingredient.capture -def load_dataset_dm_control(dm_control_env, dm_control_full_env_names, - dm_control_demo_patterns, n_traj): +def load_dataset_dm_control(dm_control_env, dm_control_full_env_names, dm_control_demo_patterns, + n_traj): # load data from all relevant paths data_pattern = dm_control_demo_patterns[dm_control_env] data_paths = glob.glob(os.path.expanduser(data_pattern)) @@ -94,22 +81,15 @@ def load_dataset_dm_control(dm_control_env, dm_control_full_env_names, # for each trajectory of length T (not including final observation), we # create an array of `dones` consisting of T-1 False values and one # terminal True value - np.array([False] * (len(t.acts) - 1) + [True], dtype='bool') - for t in loaded_trajs + np.array([False] * (len(t.acts) - 1) + [True], dtype='bool') for t in loaded_trajs ] dataset_dict = { - 'obs': - np.concatenate([t.obs[:-1] for t in loaded_trajs], axis=0), - 'acts': - np.concatenate([t.acts for t in loaded_trajs], axis=0), - 'next_obs': - np.concatenate([t.obs[1:] for t in loaded_trajs], axis=0), - 'infos': - np.concatenate([t.infos for t in loaded_trajs], axis=0), - 'rews': - np.concatenate([t.rews for t in loaded_trajs], axis=0), - 'dones': - np.concatenate(dones_lists, axis=0), + 'obs': np.concatenate([t.obs[:-1] for t in loaded_trajs], axis=0), + 'acts': np.concatenate([t.acts for t in loaded_trajs], axis=0), + 'next_obs': np.concatenate([t.obs[1:] for t in loaded_trajs], axis=0), + 'infos': np.concatenate([t.infos for t in loaded_trajs], axis=0), + 'rews': np.concatenate([t.rews for t in loaded_trajs], axis=0), + 'dones': np.concatenate(dones_lists, axis=0), } return dataset_dict diff --git a/src/il_representations/envs/magical_envs.py b/src/il_representations/envs/magical_envs.py index 3e03f372..e59551b0 100644 --- a/src/il_representations/envs/magical_envs.py +++ b/src/il_representations/envs/magical_envs.py @@ -20,8 +20,8 @@ def load_data( - pickle_paths: List[str], - preprocessor_name: str, + pickle_paths: List[str], + preprocessor_name: str, ) -> Tuple[str, il_datasets.Dataset]: """Load MAGICAL data from pickle files.""" @@ -37,9 +37,8 @@ def load_data( env_name = new_env_name else: if env_name != new_env_name: - raise ValueError( - f"supplied trajectory paths contain demos for multiple " - f"environments: {env_name}, {new_env_name} ") + raise ValueError(f"supplied trajectory paths contain demos for multiple " + f"environments: {env_name}, {new_env_name} ") demo_trajectories.append(demo_dict['trajectory']) @@ -54,9 +53,7 @@ def load_data( # the new preprocessor name. if preprocessor_name: demo_trajectories = saved_trajectories.preprocess_demos_with_wrapper( - demo_trajectories, - orig_env_name=env_name, - preproc_name=preprocessor_name) + demo_trajectories, orig_env_name=env_name, preproc_name=preprocessor_name) # Finally we build a DictDataset for actions and observations. dataset_dict = collections.defaultdict(list) @@ -93,29 +90,21 @@ def load_data( @benchmark_ingredient.capture def get_env_name_magical(magical_env_prefix, magical_preproc): orig_env_name = magical_env_prefix + '-Demo-v0' - gym_env_name = saved_trajectories.splice_in_preproc_name( - orig_env_name, magical_preproc) + gym_env_name = saved_trajectories.splice_in_preproc_name(orig_env_name, magical_preproc) return gym_env_name @benchmark_ingredient.capture -def load_dataset_magical(magical_demo_dirs, magical_env_prefix, - magical_preproc, n_traj): +def load_dataset_magical(magical_demo_dirs, magical_env_prefix, magical_preproc, n_traj): demo_dir = magical_demo_dirs[magical_env_prefix] - logging.info( - f"Loading trajectory data for '{magical_env_prefix}' from " - f"'{demo_dir}'") - demo_paths = [ - os.path.join(demo_dir, f) for f in os.listdir(demo_dir) - if f.endswith('.pkl.gz') - ] + logging.info(f"Loading trajectory data for '{magical_env_prefix}' from " f"'{demo_dir}'") + demo_paths = [os.path.join(demo_dir, f) for f in os.listdir(demo_dir) if f.endswith('.pkl.gz')] if not demo_paths: raise IOError(f"Could not find any demo pickle files in '{demo_dir}'") random.shuffle(demo_paths) if n_traj is not None: demo_paths = demo_paths[:n_traj] - dataset_dict, loaded_env_name = load_data( - demo_paths, preprocessor_name=magical_preproc) + dataset_dict, loaded_env_name = load_data(demo_paths, preprocessor_name=magical_preproc) gym_env_name = get_env_name_magical() assert loaded_env_name.startswith(gym_env_name.rsplit('-')[0]) return dataset_dict @@ -145,11 +134,11 @@ def obtain_scores(self, env_name, venv_parallel, n_envs): seed=self.seed, parallel=venv_parallel) rng = np.random.RandomState(self.seed) - trajectories = il_rollout.generate_trajectories( - self.policy, - vec_env_chans_last, - sample_until=il_rollout.min_episodes(self.n_rollouts), - rng=rng) + trajectories = il_rollout.generate_trajectories(self.policy, + vec_env_chans_last, + sample_until=il_rollout.min_episodes( + self.n_rollouts), + rng=rng) scores = [] for trajectory in trajectories[:self.n_rollouts]: scores.append(trajectory.infos[-1]['eval_score']) diff --git a/src/il_representations/il/disc_rew_nets.py b/src/il_representations/il/disc_rew_nets.py index 65da6557..a2b99067 100644 --- a/src/il_representations/il/disc_rew_nets.py +++ b/src/il_representations/il/disc_rew_nets.py @@ -1,9 +1,10 @@ """Custom discriminator/reward networks for `imitation`.""" +from stable_baselines3.common.preprocessing import get_flattened_obs_dim import torch as th from torch import nn -from stable_baselines3.common.preprocessing import get_flattened_obs_dim -from il_representations.algos.encoders import compute_rep_shape_encoder, DeterministicEncoder + +from il_representations.algos.encoders import DeterministicEncoder, compute_rep_shape_encoder class ImageDiscrimNet(nn.Module): @@ -32,7 +33,8 @@ def __init__(self, encoder_cls = DeterministicEncoder if encoder_kwargs is None: encoder_kwargs = {} - self.obs_encoder = encoder_cls(obs_space=observation_space, representation_dim=fc_dim, + self.obs_encoder = encoder_cls(obs_space=observation_space, + representation_dim=fc_dim, **encoder_kwargs) obs_out_dim = fc_dim diff --git a/src/il_representations/policy_interfacing.py b/src/il_representations/policy_interfacing.py index 07422112..fc707123 100644 --- a/src/il_representations/policy_interfacing.py +++ b/src/il_representations/policy_interfacing.py @@ -1,11 +1,18 @@ -import torch from stable_baselines3.common.policies import BaseFeaturesExtractor +import torch + from il_representations.algos.encoders import compute_rep_shape_encoder class EncoderFeatureExtractor(BaseFeaturesExtractor): - def __init__(self, observation_space, features_dim=None, encoder=None, encoder_path=None, finetune=True): - # Allow user to either pass in an existing encoder, or a path from which to load a pickled encoder + def __init__(self, + observation_space, + features_dim=None, + encoder=None, + encoder_path=None, + finetune=True): + # Allow user to either pass in an existing encoder, or a path from + # which to load a pickled encoder assert encoder is not None or encoder_path is not None, \ "You must pass in either an encoder object or a path to an encoder" assert not (encoder is not None and encoder_path is not None), \ @@ -34,8 +41,15 @@ def forward(self, observations): class EncoderSimplePolicyHead(EncoderFeatureExtractor): - # Not actually a FeatureExtractor for SB use, but a very simple Policy for use in Cynthia's BC code - def __init__(self, observation_space, features_dim, action_size, encoder=None, encoder_path=None, finetune=True): + # Not actually a FeatureExtractor for SB use, but a very simple Policy for + # use in Cynthia's BC code + def __init__(self, + observation_space, + features_dim, + action_size, + encoder=None, + encoder_path=None, + finetune=True): super().__init__(observation_space, features_dim, encoder, encoder_path, finetune) self.action_layer = torch.nn.Linear(encoder.representation_dim, action_size) self.softmax = torch.nn.Softmax(dim=-1) diff --git a/src/il_representations/scripts/il_test.py b/src/il_representations/scripts/il_test.py index 7cef0877..b4c27c9d 100644 --- a/src/il_representations/scripts/il_test.py +++ b/src/il_representations/scripts/il_test.py @@ -5,8 +5,8 @@ import logging import tempfile -import imitation.util.logger as imitation_logger import imitation.data.rollout as il_rollout +import imitation.util.logger as imitation_logger import numpy as np from sacred import Experiment from sacred.observers import FileStorageObserver @@ -14,8 +14,8 @@ import torch as th from il_representations.algos.utils import set_global_seeds -from il_representations.envs.config import benchmark_ingredient from il_representations.envs import auto +from il_representations.envs.config import benchmark_ingredient il_test_ex = Experiment('il_test', ingredients=[benchmark_ingredient]) @@ -43,8 +43,7 @@ def run(policy_path, benchmark, seed, n_rollouts, device_name, run_id): imitation_logger.configure(log_dir, ["stdout", "tensorboard"]) if policy_path is None: - raise ValueError( - "must pass a string-valued policy_path to this command") + raise ValueError("must pass a string-valued policy_path to this command") policy = th.load(policy_path) device = get_device(device_name) @@ -65,8 +64,7 @@ def run(policy_path, benchmark, seed, n_rollouts, device_name, run_id): ) eval_data_frame = eval_protocol.do_eval(verbose=False) # display to stdout - logging.info("Evaluation finished, results:\n" + - eval_data_frame.to_string()) + logging.info("Evaluation finished, results:\n" + eval_data_frame.to_string()) final_stats_dict = { 'demo_env_name': demo_env_name, 'policy_path': policy_path, @@ -79,8 +77,7 @@ def run(policy_path, benchmark, seed, n_rollouts, device_name, run_id): 'return_mean': eval_data_frame['mean_score'].mean(), } - elif (benchmark['benchmark_name'] == 'dm_control' - or benchmark['benchmark_name'] == 'atari'): + elif (benchmark['benchmark_name'] == 'dm_control' or benchmark['benchmark_name'] == 'atari'): # must import this to register envs from il_representations.envs import dm_control_envs # noqa: F401 @@ -89,17 +86,17 @@ def run(policy_path, benchmark, seed, n_rollouts, device_name, run_id): # sample some trajectories rng = np.random.RandomState(seed) - trajectories = il_rollout.generate_trajectories( - policy, vec_env, il_rollout.min_episodes(n_rollouts), rng=rng) + trajectories = il_rollout.generate_trajectories(policy, + vec_env, + il_rollout.min_episodes(n_rollouts), + rng=rng) # the "stats" dict has keys {return,len}_{min,max,mean,std} stats = il_rollout.rollout_stats(trajectories) - stats = collections.OrderedDict([(key, stats[key]) - for key in sorted(stats)]) + stats = collections.OrderedDict([(key, stats[key]) for key in sorted(stats)]) # print it out - kv_message = '\n'.join(f" {key}={value}" - for key, value in stats.items()) + kv_message = '\n'.join(f" {key}={value}" for key, value in stats.items()) logging.info(f"Evaluation stats on '{full_env_name}': {kv_message}") final_stats_dict = collections.OrderedDict([ diff --git a/src/il_representations/scripts/il_train.py b/src/il_representations/scripts/il_train.py index de512cef..5fbea458 100644 --- a/src/il_representations/scripts/il_train.py +++ b/src/il_representations/scripts/il_train.py @@ -67,9 +67,12 @@ def gail_defaults(): del _ -il_train_ex = Experiment('il_train', ingredients=[ - benchmark_ingredient, bc_ingredient, gail_ingredient, -]) +il_train_ex = Experiment('il_train', + ingredients=[ + benchmark_ingredient, + bc_ingredient, + gail_ingredient, + ]) @il_train_ex.config @@ -136,8 +139,7 @@ def make_policy(observation_space, action_space, encoder_or_path, lr_schedule=No @il_train_ex.capture -def do_training_bc(venv_chans_first, dataset, out_dir, bc, encoder, - device_name, final_pol_name): +def do_training_bc(venv_chans_first, dataset, out_dir, bc, encoder, device_name, final_pol_name): policy = make_policy(venv_chans_first.observation_space, venv_chans_first.action_space, encoder) color_space = auto_env.load_color_space() augmenter = StandardAugmentations.from_string_spec(bc['augs'], stack_color_space=color_space) @@ -201,8 +203,8 @@ def policy_constructor(observation_space, action_space, lr_schedule, use_sde=Fal learning_rate=gail['ppo_learning_rate'], ) color_space = auto_env.load_color_space() - augmenter = StandardAugmentations.from_string_spec( - gail['disc_augs'], stack_color_space=color_space) + augmenter = StandardAugmentations.from_string_spec(gail['disc_augs'], + stack_color_space=color_space) trainer = GAIL( venv_chans_first, dataset, @@ -252,16 +254,10 @@ def train(seed, algo, benchmark, encoder_path, freeze_encoder, _config): logging.info(f"Setting up '{algo}' IL algorithm") if algo == 'bc': - do_training_bc(dataset=dataset, - venv_chans_first=venv, - out_dir=log_dir, - encoder=encoder) + do_training_bc(dataset=dataset, venv_chans_first=venv, out_dir=log_dir, encoder=encoder) elif algo == 'gail': - do_training_gail(dataset=dataset, - venv_chans_first=venv, - out_dir=log_dir, - encoder=encoder) + do_training_gail(dataset=dataset, venv_chans_first=venv, out_dir=log_dir, encoder=encoder) else: raise NotImplementedError(f"Can't handle algorithm '{algo}'") diff --git a/src/il_representations/scripts/run_rep_learner.py b/src/il_representations/scripts/run_rep_learner.py index 4bbb629a..b462a96b 100644 --- a/src/il_representations/scripts/run_rep_learner.py +++ b/src/il_representations/scripts/run_rep_learner.py @@ -1,5 +1,4 @@ from glob import glob -import inspect import logging import os @@ -16,8 +15,7 @@ from il_representations.envs.config import benchmark_ingredient from il_representations.policy_interfacing import EncoderFeatureExtractor -represent_ex = Experiment('representation_learning', - ingredients=[benchmark_ingredient]) +represent_ex = Experiment('representation_learning', ingredients=[benchmark_ingredient]) @represent_ex.config @@ -54,10 +52,17 @@ def default_config(): @represent_ex.named_config def cosine_warmup_scheduler(): - algo_params = {"scheduler": LinearWarmupCosine, "scheduler_kwargs": {'warmup_epoch': 2, 'T_max': 10}} + algo_params = { + "scheduler": LinearWarmupCosine, + "scheduler_kwargs": { + 'warmup_epoch': 2, + 'T_max': 10 + } + } _ = locals() del _ + @represent_ex.named_config def ceb_breakout(): env_id = 'BreakoutNoFrameskip-v4' @@ -69,18 +74,21 @@ def ceb_breakout(): _ = locals() del _ + @represent_ex.named_config def tiny_epoch(): - demo_timesteps=5000 + demo_timesteps = 5000 _ = locals() del _ + @represent_ex.named_config def target_projection(): algo = algos.FixedVarianceTargetProjectedCEB _ = locals() del _ + @represent_ex.capture def get_random_traj(env, demo_timesteps): # Currently not designed for VecEnvs with n>1 @@ -96,8 +104,9 @@ def get_random_traj(env, demo_timesteps): def initialize_non_features_extractor(sb3_model): - # This is a hack to get around the fact that you can't initialize only some of the components of a SB3 policy - # upon creation, and we in fact want to keep the loaded representation frozen, but orthogonally initalize other + # This is a hack to get around the fact that you can't initialize only some + # of the components of a SB3 policy upon creation, and we in fact want to + # keep the loaded representation frozen, but orthogonally initalize other # components. sb3_model.policy.init_weights(sb3_model.policy.mlp_extractor, np.sqrt(2)) sb3_model.policy.init_weights(sb3_model.policy.action_net, 0.01) @@ -106,13 +115,12 @@ def initialize_non_features_extractor(sb3_model): @represent_ex.main -def run(benchmark, use_random_rollouts, algo, algo_params, - ppo_timesteps, ppo_finetune, pretrain_epochs, _config): +def run(benchmark, use_random_rollouts, algo, algo_params, ppo_timesteps, ppo_finetune, + pretrain_epochs, _config): # TODO fix to not assume FileStorageObserver always present log_dir = os.path.join(represent_ex.observers[0].dir, 'training_logs') os.mkdir(log_dir) - if isinstance(algo, str): algo = getattr(algos, algo) @@ -140,15 +148,18 @@ def run(benchmark, use_random_rollouts, algo, algo_params, encoder_checkpoint = model.encoder_checkpoints_path all_checkpoints = glob(os.path.join(encoder_checkpoint, '*')) latest_checkpoint = max(all_checkpoints, key=os.path.getctime) - encoder_feature_extractor_kwargs = {'features_dim': algo_params["representation_dim"], - 'encoder_path': latest_checkpoint} + encoder_feature_extractor_kwargs = { + 'features_dim': algo_params["representation_dim"], + 'encoder_path': latest_checkpoint + } # TODO figure out how to not have to set `ortho_init` to False for the whole policy - policy_kwargs = {'features_extractor_class': EncoderFeatureExtractor, - 'features_extractor_kwargs': encoder_feature_extractor_kwargs, - 'ortho_init': False} - ppo_model = PPO(policy=ActorCriticPolicy, env=venv, - verbose=1, policy_kwargs=policy_kwargs) + policy_kwargs = { + 'features_extractor_class': EncoderFeatureExtractor, + 'features_extractor_kwargs': encoder_feature_extractor_kwargs, + 'ortho_init': False + } + ppo_model = PPO(policy=ActorCriticPolicy, env=venv, verbose=1, policy_kwargs=policy_kwargs) ppo_model = initialize_non_features_extractor(ppo_model) ppo_model.learn(total_timesteps=ppo_timesteps) diff --git a/src/il_representations/test_support/configuration.py b/src/il_representations/test_support/configuration.py index b1b099de..96e1b2c5 100644 --- a/src/il_representations/test_support/configuration.py +++ b/src/il_representations/test_support/configuration.py @@ -2,8 +2,7 @@ from os import path CURRENT_DIR = path.dirname(path.abspath(__file__)) -TEST_DATA_DIR = path.abspath( - path.join(CURRENT_DIR, '..', '..', '..', 'tests', 'data')) +TEST_DATA_DIR = path.abspath(path.join(CURRENT_DIR, '..', '..', '..', 'tests', 'data')) COMMON_TEST_CONFIG = { 'venv_parallel': False, 'n_envs': 2, @@ -14,8 +13,7 @@ 'benchmark_name': 'atari', 'atari_env_id': 'PongNoFrameskip-v4', 'atari_demo_paths': { - 'PongNoFrameskip-v4': path.join(TEST_DATA_DIR, 'atari', - 'pong.npz'), + 'PongNoFrameskip-v4': path.join(TEST_DATA_DIR, 'atari', 'pong.npz'), }, **COMMON_TEST_CONFIG, }, @@ -23,8 +21,7 @@ 'benchmark_name': 'magical', 'magical_env_prefix': 'MoveToRegion', 'magical_demo_dirs': { - 'MoveToRegion': path.join(TEST_DATA_DIR, 'magical', - 'move-to-region'), + 'MoveToRegion': path.join(TEST_DATA_DIR, 'magical', 'move-to-region'), }, **COMMON_TEST_CONFIG, }, @@ -32,8 +29,7 @@ 'benchmark_name': 'dm_control', 'dm_control_env': 'reacher-easy', 'dm_control_demo_patterns': { - 'reacher-easy': - path.join(TEST_DATA_DIR, 'dm_control', 'reacher-easy-*.pkl.gz'), + 'reacher-easy': path.join(TEST_DATA_DIR, 'dm_control', 'reacher-easy-*.pkl.gz'), }, **COMMON_TEST_CONFIG, }, diff --git a/tests/conftest.py b/tests/conftest.py index 37a3cfdf..08a1091d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,7 @@ from il_representations.scripts.il_test import il_test_ex as _il_test_ex from il_representations.scripts.il_train import il_train_ex as _il_train_ex -from il_representations.scripts.run_rep_learner import \ - represent_ex as _represent_ex +from il_representations.scripts.run_rep_learner import represent_ex as _represent_ex @pytest.fixture diff --git a/tests/test_base_algos.py b/tests/test_base_algos.py index 6f42d7c5..676a4559 100644 --- a/tests/test_base_algos.py +++ b/tests/test_base_algos.py @@ -1,10 +1,6 @@ import inspect -import warnings -for category in [FutureWarning, DeprecationWarning, PendingDeprecationWarning]: - warnings.filterwarnings("ignore", category=category) import pytest -from sacred.observers import FileStorageObserver from il_representations import algos from il_representations.test_support.configuration import BENCHMARK_TEST_CONFIGS @@ -12,20 +8,27 @@ def is_representation_learner(el): try: - return issubclass(el, algos.RepresentationLearner) and el != algos.RepresentationLearner and el not in algos.WIP_ALGOS + return issubclass(el, algos.RepresentationLearner + ) and el != algos.RepresentationLearner and el not in algos.WIP_ALGOS except TypeError: return False -@pytest.mark.parametrize("algo", [el[1] for el in inspect.getmembers(algos) if is_representation_learner(el[1])]) +@pytest.mark.parametrize( + "algo", [el[1] for el in inspect.getmembers(algos) if is_representation_learner(el[1])]) @pytest.mark.parametrize("benchmark_cfg", BENCHMARK_TEST_CONFIGS) def test_algo(algo, benchmark_cfg, represent_ex): - represent_ex.run(config_updates={'pretrain_epochs': 1, - 'demo_timesteps': 32, - 'batch_size': 7, - 'unit_test_max_train_steps': 2, - 'algo_params': {'representation_dim': 3}, - 'algo': algo, - 'use_random_rollouts': False, - 'benchmark': benchmark_cfg, - 'ppo_finetune': False}) + represent_ex.run( + config_updates={ + 'pretrain_epochs': 1, + 'demo_timesteps': 32, + 'batch_size': 7, + 'unit_test_max_train_steps': 2, + 'algo_params': { + 'representation_dim': 3 + }, + 'algo': algo, + 'use_random_rollouts': False, + 'benchmark': benchmark_cfg, + 'ppo_finetune': False + }) diff --git a/tests/test_il_train_test.py b/tests/test_il_train_test.py index 8f928577..0293c583 100644 --- a/tests/test_il_train_test.py +++ b/tests/test_il_train_test.py @@ -8,8 +8,7 @@ @pytest.mark.parametrize("benchmark_cfg", BENCHMARK_TEST_CONFIGS) @pytest.mark.parametrize("algo", ["bc", "gail"]) -def test_il_train_test(benchmark_cfg, algo, il_train_ex, il_test_ex, - file_observer): +def test_il_train_test(benchmark_cfg, algo, il_train_ex, il_test_ex, file_observer): """Simple smoke test for training/testing IL code.""" common_cfg = { 'benchmark': benchmark_cfg, @@ -18,22 +17,22 @@ def test_il_train_test(benchmark_cfg, algo, il_train_ex, il_test_ex, final_pol_name = 'last_test_policy.pt' # train - il_train_ex.run(config_updates={ - 'algo': algo, - 'final_pol_name': final_pol_name, - # these defaults make training cheap - **FAST_IL_TRAIN_CONFIG, - **common_cfg, - }) + il_train_ex.run( + config_updates={ + 'algo': algo, + 'final_pol_name': final_pol_name, + # these defaults make training cheap + **FAST_IL_TRAIN_CONFIG, + **common_cfg, + }) # FIXME(sam): same comment as elsewhere: should have a better way of # getting at saved policies. log_dir = file_observer.dir # test policy_path = os.path.join(log_dir, final_pol_name) - il_test_ex.run( - config_updates={ - 'n_rollouts': 2, - 'policy_path': policy_path, - **common_cfg, - }) + il_test_ex.run(config_updates={ + 'n_rollouts': 2, + 'policy_path': policy_path, + **common_cfg, + }) diff --git a/tests/test_reload_policy.py b/tests/test_reload_policy.py index cc82eb32..fecef39c 100644 --- a/tests/test_reload_policy.py +++ b/tests/test_reload_policy.py @@ -20,7 +20,9 @@ def test_reload_policy(algo, represent_ex, il_train_ex, file_observer): 'pretrain_epochs': 1, 'batch_size': 7, 'unit_test_max_train_steps': 2, - 'algo_params': {'representation_dim': 3}, + 'algo_params': { + 'representation_dim': 3 + }, 'algo': MoCo, 'use_random_rollouts': False, 'benchmark': BENCHMARK_TEST_CONFIGS[0], @@ -29,9 +31,7 @@ def test_reload_policy(algo, represent_ex, il_train_ex, file_observer): # train BC using learnt representation encoder_list = glob.glob( - os.path.join( - file_observer.dir, - 'training_logs/checkpoints/representation_encoder/*.ckpt')) + os.path.join(file_observer.dir, 'training_logs/checkpoints/representation_encoder/*.ckpt')) policy_path = encoder_list[0] il_train_ex.run( config_updates={ From 293a47beed6495ad0081819738cbf0d43ff071c4 Mon Sep 17 00:00:00 2001 From: Sam Toyer Date: Wed, 7 Oct 2020 16:07:47 -0700 Subject: [PATCH 5/5] Update Docker image --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6f817070..7569e716 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -3,7 +3,7 @@ version: 2.1 jobs: build-and-test: docker: - - image: humancompatibleai/il-representations:2020.08.03-r3 + - image: humancompatibleai/il-representations:2020.10.07-r1 steps: - checkout - run: