diff --git a/.circleci/config.yml b/.circleci/config.yml
index 14c7841c..e6a61dec 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -6,7 +6,7 @@ jobs:
     docker:
       # this is also used in gcp_cluster_template.yaml (if you change this file,
       # remember to update all your cluster configs too)
-      - image: humancompatibleai/il-representations:2021.02.22
+      - image: humancompatibleai/il-representations:2021.03.16
     steps:
      - checkout
      - run:
diff --git a/Dockerfile b/Dockerfile
index 613c8058..8e1c389f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -21,6 +21,7 @@ RUN apt-get update -q \
     wget \
     xpra \
     xserver-xorg-dev \
+    libopenmpi-dev \
     libxrandr2 \
     libxss1 \
     libxcursor1 \
diff --git a/cloud/gcp_cluster_template.yaml b/cloud/gcp_cluster_template.yaml
index aa166b6d..9e65b85b 100644
--- a/cloud/gcp_cluster_template.yaml
+++ b/cloud/gcp_cluster_template.yaml
@@ -32,7 +32,7 @@ docker:
     # the image below is taken from .circleci/config.yml
     # (if you change this file or that file, then remember to change the other
     # too)
-    image: "humancompatibleai/il-representations:2021.02.22"
+    image: "humancompatibleai/il-representations:2021.03.16"
     container_name: "il-rep-ray-tune"
     # Set to true if to always force-pull the latest image version (no cache).
     pull_before_run: False
diff --git a/requirements.txt b/requirements.txt
index 369306da..95a8c019 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -48,3 +48,8 @@ git+git://github.com/HumanCompatibleAI/stable-baselines3@ad902dd5f9d4afeef347897
 magical-il~=0.0.1a4
 dm_control~=0.0.319497192
 git+git://github.com/denisyarats/dmc2gym@6e34d8acf18e92f0ea0a38ecee9564bdf2549076#egg=dmc2gym
+
+# # NOTE TO SELF: Developing locally with updated RB, need to change that here once that PR is merged
+# # Target RB branch: ilr_wrappers
+# # [Also need to make branch un-private somehow, possibly giving CircleCI a checkout key?...]
+# git+ssh://git@github.com/HumanCompatibleAI/realistic-benchmarks@master#egg=realistic-benchmarks
diff --git a/src/il_representations/algos/encoders.py b/src/il_representations/algos/encoders.py
index 1fb5091a..e4c8a855 100644
--- a/src/il_representations/algos/encoders.py
+++ b/src/il_representations/algos/encoders.py
@@ -14,6 +14,7 @@ from gym import spaces
 
 from il_representations.algos.utils import independent_multivariate_normal
+from il_representations.utils import ForkedPdb
 
 import functools
 
 """
@@ -407,7 +408,6 @@ def infer_action_shape_info(action_space, action_embedding_dim):
     and the action_embedding_dim
     """
     # Machinery for turning raw actions into vectors.
-
     if isinstance(action_space, spaces.Discrete):
         # If actions are discrete, this is done via an Embedding.
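+        # (nn.Embedding is a learned lookup table: each of the
+        # action_space.n discrete action indices maps to a dense vector of
+        # size action_embedding_dim)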
         action_processor = nn.Embedding(num_embeddings=action_space.n,
                                         embedding_dim=action_embedding_dim)
@@ -439,6 +439,7 @@ def __init__(self, obs_space, representation_dim, action_space, learn_scale=Fals
         super().__init__(obs_space, representation_dim, obs_encoder_cls,
                          learn_scale=learn_scale)
+        # TODO make this work better for more complex action types like those in Minecraft
         self.processed_action_dim, self.action_shape, self.action_processor = infer_action_shape_info(action_space, action_embedding_dim)
 
         # Machinery for aggregating information from an arbitrary number of actions into a single vector,
@@ -476,7 +477,7 @@ def forward(self, x, traj_info, action=False):
         if self.action_encoder is not None:
             output, (hidden, cell) = self.action_encoder(processed_actions)
         else:
-            hidden = torch.mean(processed_actions, dim=1)
+            hidden = torch.mean(processed_actions.float(), dim=1)
 
         action_encoding_vector = torch.squeeze(hidden)
         assert action_encoding_vector.shape[0] == batch_dim, \
diff --git a/src/il_representations/algos/representation_learner.py b/src/il_representations/algos/representation_learner.py
index 234ec6f7..ef99aa7c 100644
--- a/src/il_representations/algos/representation_learner.py
+++ b/src/il_representations/algos/representation_learner.py
@@ -290,7 +290,6 @@ def learn(self, datasets, batches_per_epoch, n_epochs, n_trajs=None, callbacks=(
             for step, batch in enumerate(dataloader):
                 # Construct batch (currently just using Torch's default batch-creator)
                 contexts, targets, traj_ts_info, extra_context = self.unpack_batch(batch)
-
                 # Use an algorithm-specific augmentation strategy to augment either
                 # just context, or both context and targets
                 contexts, targets = self._prep_tensors(contexts), self._prep_tensors(targets)
diff --git a/src/il_representations/data/write_dataset.py b/src/il_representations/data/write_dataset.py
index 0bcf1e2e..1da50a3b 100644
--- a/src/il_representations/data/write_dataset.py
+++ b/src/il_representations/data/write_dataset.py
@@ -59,6 +59,7 @@ def write_frames(out_file_path, meta_dict, frame_dicts, n_traj=None):
     with wds.TarWriter(out_file_path, keep_meta=True, compress=True) \
          as writer:  # noqa: E207
         # first write _metadata.meta.pickle containing the benchmark config
+        meta_dict['frames'] = len(frame_dicts)
         writer.dwrite(key='_metadata', meta_pickle=meta_dict)
         # now write each frame in each trajectory
         for frame_num, frame_dict in enumerate(frame_dicts):
diff --git a/src/il_representations/envs/auto.py b/src/il_representations/envs/auto.py
index bd819426..9d140d00 100644
--- a/src/il_representations/envs/auto.py
+++ b/src/il_representations/envs/auto.py
@@ -1,6 +1,7 @@
 """Code for automatically loading data, creating vecenvs, etc.
based on Sacred configuration."""
+from functools import partial
 import glob
 import logging
 import os
@@ -18,10 +19,10 @@
 from il_representations.envs.dm_control_envs import load_dataset_dm_control
 from il_representations.envs.magical_envs import (get_env_name_magical,
                                                   load_dataset_magical)
-from il_representations.envs.minecraft_envs import (MinecraftVectorWrapper,
-                                                    get_env_name_minecraft,
+from il_representations.envs.minecraft_envs import (get_env_name_minecraft,
                                                     load_dataset_minecraft)
 from il_representations.scripts.utils import update as dict_update
+from il_representations.envs.utils import wrap_env
 
 ERROR_MESSAGE = "no support for benchmark_name={benchmark_name!r}"
@@ -53,6 +54,7 @@ def benchmark_is_available(benchmark_name):
         # we check whether minecraft is installed by importing minerl
         try:
             import minerl  # noqa: F401
+            import realistic_benchmarks  # noqa: F401
             return True, None
         except ImportError as ex:
             return False, "MineRL not installed, cannot use Minecraft " \
@@ -62,7 +64,7 @@
 
 
 @env_cfg_ingredient.capture
-def load_dict_dataset(benchmark_name, n_traj=None):
+def load_dict_dataset(benchmark_name, n_traj, frames_per_traj):
     """Load a dict-type dataset. Also see load_wds_datasets, which instead
     loads a set of datasets that have been stored in a webdataset-compatible
     format."""
@@ -73,7 +75,7 @@
     elif benchmark_name == 'atari':
         dataset_dict = load_dataset_atari(n_traj=n_traj)
     elif benchmark_name == 'minecraft':
-        dataset_dict = load_dataset_minecraft(n_traj=n_traj)
+        dataset_dict = load_dataset_minecraft(n_traj=n_traj, frames_per_traj=frames_per_traj)
     else:
         raise NotImplementedError(ERROR_MESSAGE.format(**locals()))
@@ -110,9 +112,10 @@ def _get_venv_opts(n_envs, venv_parallel, parallel_workers):
     return n_envs, venv_parallel, parallel_workers
 
+
 @env_cfg_ingredient.capture
 def load_vec_env(benchmark_name, dm_control_full_env_names,
-                 dm_control_frame_stack, minecraft_max_env_steps):
+                 dm_control_frame_stack, minecraft_max_env_steps, minecraft_wrappers):
     """Create a vec env for the selected benchmark task and wrap it with any
     necessary wrappers."""
     n_envs, venv_parallel, parallel_workers = _get_venv_opts()
@@ -161,7 +164,7 @@
             n_envs=1,  # TODO fix this eventually; currently hitting error
                        # noted here: https://github.com/minerllabs/minerl/issues/177
             parallel=venv_parallel,
-            wrapper_class=MinecraftVectorWrapper,
+            wrapper_class=partial(wrap_env, wrappers=minecraft_wrappers),
             max_episode_steps=minecraft_max_env_steps)
     raise NotImplementedError(ERROR_MESSAGE.format(**locals()))
diff --git a/src/il_representations/envs/config.py b/src/il_representations/envs/config.py
index 6a6c3bff..3c3bb527 100644
--- a/src/il_representations/envs/config.py
+++ b/src/il_representations/envs/config.py
@@ -3,8 +3,15 @@
 vecenvs (`venv_opts_ingredient`), and for loading data
 (`env_data_ingredient`)."""
 import os
-
+import logging
 from sacred import Ingredient
+from il_representations.envs.utils import MinecraftPOVWrapper, Testing2500StepLimitWrapper
+try:
+    import realistic_benchmarks.wrappers as rb_wrappers
+except ImportError as e:
+    logging.warning(f"Realistic Benchmarks is not installed ({e}); much "
+                    "Minecraft functionality will not work")
+
 
 ALL_BENCHMARK_NAMES = {"atari", "magical", "dm_control", "minecraft"}
@@ -69,11 +76,49 @@ def env_cfg_defaults():
     # Minecraft-specific config variables #
     ###############################
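+    # minecraft_max_env_steps caps episode length (None means no limit);
+    # minecraft_wrappers is an ordered list of wrapper classes that wrap_env
+    # applies in sequence (the first entry becomes the innermost wrapper)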
     minecraft_max_env_steps = None
+    minecraft_wrappers = list()
+    _ = locals()
+    del _
+
+
+@env_cfg_ingredient.named_config
+def wrappers_frames_only_camera_disc_no_limit():
+    minecraft_wrappers = [rb_wrappers.CameraDiscretizationWrapper, rb_wrappers.ActionFlatteningWrapper,
+                          MinecraftPOVWrapper]
+    _ = locals()
+    del _
+
+
+@env_cfg_ingredient.named_config
+def wrappers_frames_only_camera_disc():
+    minecraft_wrappers = [rb_wrappers.CameraDiscretizationWrapper, rb_wrappers.ActionFlatteningWrapper,
+                          MinecraftPOVWrapper, Testing2500StepLimitWrapper]
+    _ = locals()
+    del _
+
+
+@env_cfg_ingredient.named_config
+def treechop_wrappers_with_frameskip():
+    minecraft_wrappers = [rb_wrappers.CameraDiscretizationWrapper, rb_wrappers.ActionFlatteningWrapper,
+                          MinecraftPOVWrapper, rb_wrappers.FrameSkip]
     _ = locals()
     del _
+
+
+@env_cfg_ingredient.named_config
+def wrappers_frames_only_obfuscated():
+    minecraft_wrappers = [rb_wrappers.ActionFlatteningWrapper,
+                          MinecraftPOVWrapper, Testing2500StepLimitWrapper]
+    _ = locals()
+    del _
+
+
+@env_cfg_ingredient.named_config
+def wrappers_obs_flatten_camera_disc():
+    minecraft_wrappers = [rb_wrappers.CameraDiscretizationWrapper, rb_wrappers.ActionFlatteningWrapper,
+                          rb_wrappers.ObservationFlatteningWrapper, Testing2500StepLimitWrapper]
+    _ = locals()
+    del _
+
 
 # see venv_opts_defaults docstring for description of this ingredient
 venv_opts_ingredient = Ingredient('venv_opts')
diff --git a/src/il_representations/envs/minecraft_envs.py b/src/il_representations/envs/minecraft_envs.py
index ce38874f..488b4be8 100644
--- a/src/il_representations/envs/minecraft_envs.py
+++ b/src/il_representations/envs/minecraft_envs.py
@@ -4,15 +4,41 @@
 from gym import Wrapper, spaces
 import numpy as np
-
+from copy import deepcopy
 from il_representations.envs.config import (env_cfg_ingredient,
-                                            env_data_ingredient,
-                                            venv_opts_ingredient)
+                                            env_data_ingredient)
+from il_representations.envs.utils import wrap_env
+
+
+def optional_observation_map(env, inner_obs):
+    if hasattr(env, 'observation'):
+        return env.observation(inner_obs)
+    else:
+        return inner_obs
+
+
+def optional_action_map(env, inner_action):
+    if hasattr(env, 'wrap_action'):
+        return env.wrap_action(inner_action)
+    else:
+        return inner_action
+
+
+def remove_iterator_dimension(dict_obs_or_act):
+    output_dict = dict()
+    for k in dict_obs_or_act.keys():
+        if isinstance(dict_obs_or_act[k], dict):
+            output_dict[k] = remove_iterator_dimension(dict_obs_or_act[k])
+        else:
+            output_dict[k] = dict_obs_or_act[k][0]
+    return output_dict
+
 
 @env_cfg_ingredient.capture
 def get_env_name_minecraft(task_name):
-    return f"MineRL{task_name}-v0"
+    if task_name in ('FindCaves',):
+        return f"{task_name}-v0"
+    else:
+        return f"MineRL{task_name}-v0"
 
 
 @env_data_ingredient.capture
@@ -25,33 +51,53 @@ def _get_data_root(data_root):
 # even though it can only be notated as a capture function for one
 # ingredient at a time
 @env_cfg_ingredient.capture
-def load_dataset_minecraft(n_traj=None, chunk_length=100):
+def load_dataset_minecraft(minecraft_wrappers, n_traj, frames_per_traj, chunk_length=100):
     import minerl  # lazy-load in case it is not installed
+    import realistic_benchmarks.envs.envs  # Registers new environments
+    from realistic_benchmarks.utils import DummyEnv
     data_root = _get_data_root()
     env_name = get_env_name_minecraft()
     minecraft_data_root = os.path.join(data_root, 'minecraft')
-    data_iterator = minerl.data.make(environment=env_name,
+    data_pipeline = minerl.data.make(environment=env_name,
                                      data_dir=minecraft_data_root)
-
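+    # Accumulate per-frame data trajectory-by-trajectory. Unlike the old
+    # batch_iter-based loader, next_obs now comes straight from the MineRL
+    # data pipeline, so it no longer has to be reconstructed afterwards.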
-    appended_trajectories = {'obs': [], 'acts': [], 'dones': []}
+    appended_trajectories = {'obs': [], 'acts': [], 'dones': [], 'next_obs': []}
     start_time = time.time()
-    for current_state, action, reward, next_state, done in data_iterator.batch_iter(batch_size=1,
-                                                                                    num_epochs=1,
-                                                                                    seq_len=chunk_length,
-                                                                                    epoch_size=n_traj):
-        # Data returned from the data_iterator is in batches of size `batch_size` x `chunk_size`
-        # The zero-indexing is to remove the extra extraneous `batch_size` dimension,
-        # which has been hardcoded to 1
-        appended_trajectories['obs'].append(MinecraftVectorWrapper.transform_obs(current_state)[0])
-        appended_trajectories['acts'].append(MinecraftVectorWrapper.extract_action(action)[0])
-        appended_trajectories['dones'].append(done[0])
-    # Now, we need to go through and construct `next_obs` values, which aren't natively returned
-    # by the environment
-    merged_trajectories = {k: np.concatenate(v, axis=0) for k, v in appended_trajectories.items()}
-    merged_trajectories = construct_next_obs(merged_trajectories)
+
+    env_spec = deepcopy(data_pipeline.spec)
+    dummy_env = DummyEnv(action_space=env_spec._action_space,
+                         observation_space=env_spec._observation_space)
+    wrapped_dummy_env = wrap_env(dummy_env, minecraft_wrappers)
+    timesteps = 0
+
+    trajectory_names = data_pipeline.get_trajectory_names()
+    if n_traj is None:
+        # no limit requested: load every available trajectory
+        trajectory_subset = trajectory_names
+    else:
+        trajectory_subset = np.random.choice(trajectory_names, size=n_traj,
+                                             replace=False)
+    for trajectory_name in trajectory_subset:
+        data_loader = data_pipeline.load_data(trajectory_name)
+        traj_frame_count = 0
+        for current_obs, action, reward, next_obs, done in data_loader:
+            wrapped_obs = optional_observation_map(wrapped_dummy_env, current_obs)
+            wrapped_next_obs = optional_observation_map(wrapped_dummy_env, next_obs)
+            wrapped_action = optional_action_map(wrapped_dummy_env, action)
+            appended_trajectories['obs'].append(wrapped_obs)
+            appended_trajectories['next_obs'].append(wrapped_next_obs)
+            appended_trajectories['acts'].append(wrapped_action)
+            appended_trajectories['dones'].append(done)
+            traj_frame_count += 1
+            timesteps += 1
+
+            if timesteps % 1000 == 0:
+                logging.info(f"{timesteps} timesteps loaded")
+
+            # if frames_per_traj is None, collect the whole trajectory
+            if frames_per_traj is not None and traj_frame_count == frames_per_traj:
+                appended_trajectories['dones'][-1] = True
+                break
     end_time = time.time()
+    for k in appended_trajectories:
+        appended_trajectories[k] = np.array(appended_trajectories[k])
     logging.info(f"Minecraft trajectory collection took {round(end_time - start_time, 2)} seconds to complete")
-    merged_trajectories['dones'][-1] = True
-    return merged_trajectories
+    appended_trajectories['dones'][-1] = True
+    return appended_trajectories
 
 
 def construct_next_obs(trajectories_dict):
@@ -62,19 +108,26 @@
     dones_locations = np.append(dones_locations, -1)
     prior_dones_loc = 0
     all_next_obs = []
+    logging.debug(f"Done locations to process: {dones_locations}")
     for done_loc in dones_locations:
         if done_loc == -1:
             trajectory_obs = trajectories_dict['obs'][prior_dones_loc:]
         else:
             trajectory_obs = trajectories_dict['obs'][prior_dones_loc:done_loc+1]
         next_obs = trajectory_obs[1:]
-        next_obs = np.append(next_obs, np.expand_dims(trajectory_obs[-1], axis=0), axis=0)  # duplicate final obs for final next_obs
+        final_obs = np.expand_dims(trajectory_obs[-1], axis=0)
+        next_obs = np.append(next_obs, final_obs, axis=0)  # duplicate final obs for final next_obs
         all_next_obs.append(next_obs)
+        prior_dones_loc = done_loc
+
+    del next_obs
+    del trajectory_obs
     if len(all_next_obs) == 1:
-        merged_next_obs = all_next_obs[0]
+        all_next_obs = all_next_obs[0]
     else:
-        merged_next_obs = np.concatenate(all_next_obs)
-    trajectories_dict['next_obs'] = merged_next_obs
+        # concatenate in place (rebinding all_next_obs so the per-trajectory
+        # list can be freed) to keep peak memory usage down
+        all_next_obs = np.concatenate(all_next_obs)
+    trajectories_dict['next_obs'] = all_next_obs
     return trajectories_dict
@@ -86,45 +139,4 @@ def channels_first(el):
         return (el[2], el[0], el[1])
     else:
-        raise NotImplementedError("Input must be either array or tuple")
-
-
-class MinecraftVectorWrapper(Wrapper):
-    """
-    Currently, RepL code only works with pixel inputs, and imitation can only work with vector (rather than dict)
-    action spaces. So, we currently (1) only allow VectorObfuscated environments (where the action dictionary
-    has been processed into a vector), and (2) extract the observation space to only save the pixels, before we load
-    the data in as a il_representations dataset
-    """
-    def __init__(self, env):
-        super().__init__(env)
-        assert 'vector' in env.action_space.spaces.keys(), "Wrapper is only implemented to work with Vector Obfuscated envs"
-        self.action_space = env.action_space.spaces['vector']
-        pov_space = env.observation_space.spaces['pov']
-        transposed_pov_space = spaces.Box(low=channels_first(pov_space.low),
-                                          high=channels_first(pov_space.high),
-                                          shape=channels_first(pov_space.shape),
-                                          dtype=np.uint8)
-        self.observation_space = transposed_pov_space
-
-    @staticmethod
-    def transform_obs(obs):
-        return channels_first(obs['pov']).astype(np.uint8)
-
-    @staticmethod
-    def extract_action(action):
-        return action['vector']
-
-    @staticmethod
-    def dictify_action(action):
-        return {'vector': action}
-
-    def step(self, action):
-        obs, rew, dones, infos = self.env.step(MinecraftVectorWrapper.dictify_action(action))
-        transformed_obs = MinecraftVectorWrapper.transform_obs(obs)
-        return transformed_obs, rew, dones, infos
-
-    def reset(self):
-        obs = self.env.reset()
-        return MinecraftVectorWrapper.transform_obs(obs)
-
+        raise NotImplementedError("Input must be either array or tuple")
\ No newline at end of file
diff --git a/src/il_representations/envs/utils.py b/src/il_representations/envs/utils.py
index 6d8bc60b..54addb25 100644
--- a/src/il_representations/envs/utils.py
+++ b/src/il_representations/envs/utils.py
@@ -1,4 +1,38 @@
 import gym
+from gym import ObservationWrapper
+from gym.wrappers import TimeLimit
+import numpy as np
+
+
+class MinecraftPOVWrapper(ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        non_transposed_shape = self.env.observation_space['pov'].shape
+        self.high = np.max(self.env.observation_space['pov'].high)
+        transposed_shape = (non_transposed_shape[2],
+                            non_transposed_shape[0],
+                            non_transposed_shape[1])
+        # Note: the original 'pov' Box has per-pixel (vector) low/high
+        # bounds; here we normalise to a scalar [0, 1] range instead
+        transposed_obs_space = gym.spaces.Box(low=0,
+                                              high=1,
+                                              shape=transposed_shape)
+        self.observation_space = transposed_obs_space
+
+    def observation(self, obs):
+        # Minecraft returns observations in HWC layout with unnormalized
+        # pixel values; move the channel axis to the front and rescale to
+        # get CHW observations in [0, 1]
+        return np.moveaxis(obs['pov'], -1, -3) / self.high
+
+
+# TODO This is just a hack for dealing with the fact that currently FindCaves
+# never reaches an episode termination condition
+class Testing2500StepLimitWrapper(TimeLimit):
+    def __init__(self, env):
+        super().__init__(env, 2500)
+
+
+def wrap_env(env, wrappers):
+    for wrapper in wrappers:
+        env = wrapper(env)
+    return env
 
 
 def serialize_gym_space(space):
@@ -6,10 +40,16 @@ def serialize_gym_space(space):
     (i.e. for pickles that will be transferred between machines running
     different versions of Gym)."""
     if isinstance(space, gym.spaces.Box):
+        if not np.isscalar(space.low):
+            # Gym quirk: if low/high are arrays, Box infers its shape from
+            # them, and also passing an explicit shape can fail
+            space_shape = None
+        else:
+            space_shape = space.shape
         return _KwargSerialisableObject(gym.spaces.Box, {
             'low': space.low,
             'high': space.high,
-            'shape': space.shape,
+            'shape': space_shape,
             'dtype': space.dtype,
         })
     elif isinstance(space, gym.spaces.Discrete):
diff --git a/src/il_representations/il/bc_support.py b/src/il_representations/il/bc_support.py
index 4054b2c1..48988cff 100644
--- a/src/il_representations/il/bc_support.py
+++ b/src/il_representations/il/bc_support.py
@@ -1,7 +1,21 @@
 """Support code for using imitation's BC implementation."""
+import collections
+import json
+import logging
 import os
 
 import torch as th
+import imitation.data.rollout as il_rollout
+
+
+class MultiCallback:
+    """Callback that allows multiple callbacks to be passed into `on_epoch_end`."""
+    def __init__(self, callbacks):
+        self.callbacks = callbacks
+
+    def __call__(self, **kwargs):
+        for callback in self.callbacks:
+            callback(**kwargs)
 
 
 class BCModelSaver:
@@ -24,3 +38,42 @@ def __call__(self, **kwargs):
         th.save(self.policy, save_path)
         print(f"Saved policy to {save_path}!")
         self.last_save_batches = self.batch_count
+
+
+class IntermediateRolloutEvaluator:
+    """Callback that rolls out the BC policy every N batches and saves the
+    resulting evaluation stats."""
+    def __init__(self, policy, vec_env, save_dir, epoch_length,
+                 evaluate_interval_batches, n_rollouts):
+        self.policy = policy
+        self.vec_env = vec_env
+        self.save_dir = save_dir
+        self.last_save_batches = 0
+        self.evaluate_interval_batches = evaluate_interval_batches
+        self.batch_count = 0
+        self.epoch_length = epoch_length
+        self.n_rollouts = n_rollouts
+
+    def get_stats(self):
+        # Adapted from il_test
+        trajectories = il_rollout.generate_trajectories(
+            self.policy, self.vec_env, il_rollout.min_episodes(self.n_rollouts))
+        stats = il_rollout.rollout_stats(trajectories)
+        stats = collections.OrderedDict([(key, stats[key])
+                                         for key in sorted(stats)])
+        return stats
+
+    def __call__(self, **kwargs):
+        """It is assumed that this is called on epoch end."""
+        self.batch_count += self.epoch_length
+        if self.batch_count >= self.last_save_batches + self.evaluate_interval_batches:
+            stats = self.get_stats()
+            kv_message = '\n'.join(f"  {key}={value}"
+                                   for key, value in stats.items())
+            logging.info(f"Evaluation stats at {self.batch_count:08d} batches: {kv_message}")
+
+            os.makedirs(self.save_dir, exist_ok=True)
+            save_filename = f'evaluation_{self.batch_count:08d}_batches.json'
+            save_path = os.path.join(self.save_dir, save_filename)
+            with open(save_path, 'w') as fp:
+                json.dump(stats, fp, indent=2, sort_keys=False)
+            print(f"Rolled out {self.n_rollouts} trajectories, saved stats to {save_path}!")
+            self.last_save_batches = self.batch_count
diff --git a/src/il_representations/scripts/chain_configs.py b/src/il_representations/scripts/chain_configs.py
index ad3d51bb..e9a077f3 100644
--- a/src/il_representations/scripts/chain_configs.py
+++ b/src/il_representations/scripts/chain_configs.py
@@ -1,4 +1,5 @@
 from il_representations.scripts.utils import StagesToRun
+import logging
 from ray import tune
 
 # TODO(sam): GAIL configs
@@ -317,6 +318,17 @@ def cfg_repl_temporal_cpc():
         _ = locals()
         del _
 
+    @experiment_obj.named_config
+    def cfg_data_repl_minecraft_survival():
+        """Train specifically on survival data; only for use with
+        Minecraft!"""
+        repl = {
+            'dataset_configs': [{'type': 'frames_only_demos',
+                                 'env_cfg': {'task_name': 'ObtainDiamondSurvivalVectorObf'}}],
+            'is_multitask': True
+        }
+        _ = locals()
+        del _
+
     @experiment_obj.named_config
     def cfg_data_repl_demos_random():
         """Training on both demos and random rollouts for the current
@@ -329,7 +341,7 @@
 
     @experiment_obj.named_config
     def cfg_data_repl_random():
-        """Training on both demos and random rollouts for the current
+        """Training on only random rollouts for the current
         environment."""
         repl = {
             'dataset_configs': [{'type': 'random'}],
diff --git a/src/il_representations/scripts/convert_minecraft_data.sh b/src/il_representations/scripts/convert_minecraft_data.sh
index 55a41340..8191f91a 100644
--- a/src/il_representations/scripts/convert_minecraft_data.sh
+++ b/src/il_representations/scripts/convert_minecraft_data.sh
@@ -1,2 +1,5 @@
 python -m il_representations.scripts.mkdataset_demos run with n_traj_total=10 \
-    env_cfg.benchmark_name=minecraft env_cfg.task_name=NavigateVectorObf venv_opts.venv_parallel=False
\ No newline at end of file
+    env_cfg.use_dict_wrappers env_cfg.benchmark_name=minecraft env_cfg.task_name=FindCaves \
+    venv_opts.venv_parallel=False
+
+
diff --git a/src/il_representations/scripts/il_train.py b/src/il_representations/scripts/il_train.py
index 96482701..da21bde5 100644
--- a/src/il_representations/scripts/il_train.py
+++ b/src/il_representations/scripts/il_train.py
@@ -26,12 +26,18 @@
 from il_representations.envs.config import (env_cfg_ingredient,
                                             env_data_ingredient,
                                             venv_opts_ingredient)
-from il_representations.il.bc_support import BCModelSaver
+from il_representations.il.bc_support import BCModelSaver, IntermediateRolloutEvaluator, MultiCallback
 from il_representations.il.disc_rew_nets import ImageDiscrimNet
 from il_representations.il.gail_pol_save import GAILSavePolicyCallback
 from il_representations.il.score_logging import SB3ScoreLoggingCallback
 from il_representations.policy_interfacing import EncoderFeatureExtractor
 from il_representations.utils import freeze_params
+from il_representations.scripts.utils import print_policy_info
+
+try:
+    import realistic_benchmarks.policies as rb_policies
+except ImportError:
+    logging.info("Realistic Benchmarks is not installed; as a result much "
+                 "Minecraft functionality will not work")
 
 bc_ingredient = Ingredient('bc')
@@ -49,15 +55,22 @@ def bc_defaults():
     # (however, large numbers prevent us from having to recreate the
     # data iterator frequently)
     nominal_length = int(1e6)
-    save_every_n_batches = nominal_length
-
+    save_every_n_batches = nominal_length  # equivalent to "do this every epoch"
+    evaluate_every_n_batches = None  # None disables intermediate evaluation
+    n_evaluation_rollouts = 10
     _ = locals()
     del _
 
+
+# @bc_ingredient.named_config
+# def evaluate_every_10000_batches():
+#     batch_size = 32
+#     nominal_length = int(batch_size*1e4)
+#     evaluate_every_n_batches = nominal_length
+#     _ = locals()
+#     del _
+
 gail_ingredient = Ingredient('gail')
 
-
 @gail_ingredient.config
 def gail_defaults():
     # These default settings are copied from
@@ -136,6 +149,13 @@ def default_config():
     encoder_path = None
     # file name for final policy
     final_pol_name = 'policy_final.pt'
+    # Do we want to save the encoder after IL training?
+    # This is useful for experiments comparing off-task BC pretraining
+    # to off-task RepL pretraining
+    save_output_encoder = False
+    # The path at which the encoder will be saved after IL training
+    # (only used if `save_output_encoder=True`)
+    output_encoder_path = 'il_encoder.ckpt'
     # dataset configurations for webdataset code
     # (you probably don't want to change this)
     dataset_configs = [{'type': 'demos'}]
@@ -155,16 +175,29 @@
         representation_dim=128,
         obs_encoder_cls_kwargs={}
     )
-
+    policy_class = sb3_pols.ActorCriticCnnPolicy
+    extra_policy_kwargs = dict(features_extractor_class=EncoderFeatureExtractor)
+    add_env_to_policy_kwargs = False
     _ = locals()
     del _
 
 
+@il_train_ex.named_config
+def minecraft_action_wrapped():
+    # TODO need to define a POV-only features extractor
+    policy_class = rb_policies.SpaceFlatteningActorCriticPolicy
+    add_env_to_policy_kwargs = True
+    _ = locals()
+    del _
+
+
 @il_train_ex.capture
-def make_policy(observation_space,
+def make_policy(venv,
+                observation_space,
                 action_space,
                 encoder_or_path,
                 encoder_kwargs,
+                policy_class,
+                extra_policy_kwargs,
+                add_env_to_policy_kwargs,
                 lr_schedule=None):
     # TODO(sam): this should be unified with the representation learning code
     # so that it can be configured in the same way, with the same default
@@ -191,13 +224,15 @@
     else:
         encoder = BaseEncoder(observation_space, **encoder_kwargs)
     policy_kwargs = {
-        'features_extractor_class': EncoderFeatureExtractor,
         'features_extractor_kwargs': {
             "encoder": encoder,
         },
         **common_policy_kwargs,
+        **extra_policy_kwargs
     }
-    policy = sb3_pols.ActorCriticCnnPolicy(**policy_kwargs)
+    if add_env_to_policy_kwargs:
+        policy_kwargs['env'] = venv
+    policy = policy_class(**policy_kwargs)
     return policy
@@ -217,8 +252,10 @@ def add_infos(data_iter):
 
 @il_train_ex.capture
 def do_training_bc(venv_chans_first, demo_webdatasets, out_dir, bc, encoder,
-                   device_name, final_pol_name, shuffle_buffer_size):
-    policy = make_policy(observation_space=venv_chans_first.observation_space,
+                   device_name, final_pol_name, shuffle_buffer_size,
+                   save_output_encoder, output_encoder_path):
+    policy = make_policy(venv=venv_chans_first,
+                         observation_space=venv_chans_first.observation_space,
                          action_space=venv_chans_first.action_space,
                          encoder_or_path=encoder)
     color_space = auto_env.load_color_space()
@@ -250,23 +287,43 @@
     )
 
     save_interval = bc['save_every_n_batches']
+    evaluate_interval = bc['evaluate_every_n_batches']
+    n_evaluation_rollouts = bc['n_evaluation_rollouts']
+    callbacks = []
     if save_interval is not None:
         optional_model_saver = BCModelSaver(policy,
                                             os.path.join(out_dir, 'snapshots'),
                                             bc['nominal_length'],
                                             save_interval)
+        callbacks.append(optional_model_saver)
+    if evaluate_interval is not None:
+        optional_model_evaluator = IntermediateRolloutEvaluator(policy,
+                                                                venv_chans_first,
+                                                                os.path.join(out_dir, 'evaluations'),
+                                                                bc['nominal_length'],
+                                                                evaluate_interval,
+                                                                n_evaluation_rollouts)
+        callbacks.append(optional_model_evaluator)
+
+    if not callbacks:
+        callback_op = None
     else:
-        optional_model_saver = None
+        callback_op = MultiCallback(callbacks)
 
     logging.info("Beginning BC training")
    trainer.train(n_epochs=None,
                  n_batches=bc['n_batches'],
                  log_interval=bc['log_interval'],
-                 on_epoch_end=optional_model_saver)
+                 on_epoch_end=callback_op)
 
     final_path = os.path.join(out_dir, final_pol_name)
     logging.info(f"Saving final BC policy to {final_path}")
     trainer.save_policy(final_path)
+
+    if save_output_encoder:
+        final_encoder_path = os.path.join(out_dir, output_encoder_path)
+        logging.info(f"Saving encoder component of final BC policy to {final_encoder_path}")
+        th.save(policy.features_extractor.representation_encoder, final_encoder_path)
     return final_path
@@ -288,7 +345,6 @@ def do_training_gail(
         action_space=venv_chans_first.action_space,
         encoder=encoder,
     )
-
     def policy_constructor(observation_space,
                            action_space,
                            lr_schedule,
@@ -296,7 +352,8 @@ def policy_constructor(observation_space,
         """Construct a policy with the right LR schedule (since PPO will
         actually use it, unlike BC)."""
         assert not use_sde
-        return make_policy(observation_space=observation_space,
+        return make_policy(venv=venv_chans_first,
+                           observation_space=observation_space,
                            action_space=action_space,
                            encoder_or_path=encoder,
                            lr_schedule=lr_schedule)
diff --git a/src/il_representations/scripts/mkdataset_demos.py b/src/il_representations/scripts/mkdataset_demos.py
index 8a543695..a733923f 100644
--- a/src/il_representations/scripts/mkdataset_demos.py
+++ b/src/il_representations/scripts/mkdataset_demos.py
@@ -32,21 +32,24 @@ def default_config():
     shuffle_traj_order = True
     # put an upper limit on number of trajectories to load
     n_traj_total = None
-    # TODO(sam): support sharding
+    # TODO maybe implement for envs other than Minecraft?
+    frames_per_traj = None
+
+    # TODO(sam): support sharding
+    data_type = "demos"
     _ = locals()
     del _
 
 
 @mkdataset_demos_ex.main
-def run(seed, env_data, env_cfg, shuffle_traj_order, n_traj_total):
+def run(seed, env_data, env_cfg, shuffle_traj_order, n_traj_total, frames_per_traj, data_type):
     set_global_seeds(seed)
     # python built-in logging
     logging.basicConfig(level=logging.INFO)
     # load existing demo dictionary directly, w/ same code used to handle data
     # in il_train.py
-    dataset_dict = auto_env.load_dict_dataset(n_traj=n_traj_total)
+    dataset_dict = auto_env.load_dict_dataset(n_traj=n_traj_total, frames_per_traj=frames_per_traj)
     n_samples = len(dataset_dict['obs'])
     # keys in dataset_dict: 'obs', 'next_obs', 'acts', 'infos', 'rews', 'dones'
     # numeric_types = (np.ndarray, numbers.Number, np.bool_)
@@ -86,7 +89,7 @@ def run(seed, env_data, env_cfg, shuffle_traj_order, n_traj_total):
     out_file_path = os.path.join(
         auto_env.get_data_dir(benchmark_name=env_cfg['benchmark_name'],
                               task_key=env_cfg['task_name'],
-                              data_type='demos'), 'demos.tgz')
+                              data_type=data_type), 'demos.tgz')
 
     # get metadata for the dataset
     meta_dict = get_meta_dict()
@@ -94,7 +97,6 @@
     if shuffle_traj_order:
         # write trajectories in random order
         random.shuffle(trajectories)
-
     def frame_gen():
         for traj_num, traj in enumerate(trajectories):
             traj_len = len(traj['obs'])
diff --git a/src/il_representations/scripts/mkdataset_random.py b/src/il_representations/scripts/mkdataset_random.py
index fd3cc137..bc6fc365 100644
--- a/src/il_representations/scripts/mkdataset_random.py
+++ b/src/il_representations/scripts/mkdataset_random.py
@@ -89,6 +89,9 @@ def frame_iter():
             # yield a dictionary for each frame in the retrieved
             # trajectories
             for idx in range(T):
+                # drop the bulky 'rollout' info entry on terminal frames
+                if dones[idx] and 'rollout' in traj.infos[idx]:
+                    del traj.infos[idx]['rollout']
                 yield {
                     # Keys in dataset_dict: 'obs', 'next_obs', 'acts',
                     # 'infos', 'rews', 'dones'.
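For reference, a hypothetical invocation exercising the new mkdataset_demos options above (the task name and counts are illustrative, not tested): `frames_per_traj` truncates each loaded trajectory, and `data_type` selects the output subdirectory passed to get_data_dir.

python -m il_representations.scripts.mkdataset_demos run with \
    env_cfg.benchmark_name=minecraft env_cfg.task_name=FindCaves \
    n_traj_total=10 frames_per_traj=1000 data_type=frames_only_demos \
    venv_opts.venv_parallel=False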
diff --git a/src/il_representations/scripts/render_dataset.py b/src/il_representations/scripts/render_dataset.py
new file mode 100644
index 00000000..3a2cf1f8
--- /dev/null
+++ b/src/il_representations/scripts/render_dataset.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+"""Loads a webdataset and renders it into images, while printing out some
+debugging info. Useful for verifying that the file contains what you think it
+contains!"""
+import itertools as it
+import logging
+import os
+import pprint
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from PIL import Image
+import sacred
+from sacred import Experiment
+
+from il_representations.algos.utils import set_global_seeds
+from il_representations.envs import auto
+from il_representations.envs.config import (env_cfg_ingredient,
+                                            env_data_ingredient)
+from il_representations.utils import NUM_CHANS
+
+sacred.SETTINGS['CAPTURE_MODE'] = 'no'  # workaround for sacred issue#740
+render_dataset_ex = Experiment(
+    'render_dataset',
+    ingredients=[
+        env_cfg_ingredient, env_data_ingredient
+    ])
+
+
+@render_dataset_ex.config
+def default_config():
+    # number of trajectories to illustrate
+    n_traj = None
+    # number of frames to write per trajectory (default: all of them)
+    # (if more frames are specified than the length of the trajectory, then
+    # some will be repeated)
+    frames_per_traj = 10
+    # where to write output?
+    out_dir = None
+    # config to load
+    dataset_config = {'type': 'demos'}
+    # when dealing with frame stacks, drop all but the latest frame
+    keep_only_latest = False
+    # size of border around images
+    border_size = 4
+
+    _ = locals()
+    del _
+
+
+@render_dataset_ex.named_config
+def random_data():
+    dataset_config = {'type': 'random'}
+    _ = locals()
+    del _
+
+
+def trajectory_iter(dataset):
+    """Yields one trajectory at a time from a webdataset."""
+    traj = []
+    ind = 0
+    for frame in dataset:
+        traj.append(frame)
+        print(f"Appended frame {ind}")
+        ind += 1
+        if frame['dones']:
+            yield traj
+            traj = []
+
+
+def sample_points(traj_len: int, n_points: Optional[int] = None) -> np.ndarray:
+    """Collect `n_points` indices into array, spaced ~evenly (or just return
+    all points, if n_points is None)."""
+    if n_points is None:
+        return np.arange(traj_len)
+    lin_samples = np.linspace(0, traj_len - 1, n_points)
+    rounded = np.round(lin_samples)
+    return rounded.astype('int64')
+
+
+def concat_traj(traj: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
+    """Combine the per-step dictionaries that make up a trajectory into a
+    single dictionary that maps keys to concatenated values."""
+    frame0: Dict[str, Any] = traj[0]
+    keys_to_stack: List[str] = []
+    for key, value in frame0.items():
+        if isinstance(value, np.ndarray):
+            keys_to_stack.append(key)
+    # for some reason using a dict comprehension here was confusing pytype
+    # (2020.08.10)
+    rv_dict = {}
+    for key in keys_to_stack:
+        stacked = np.stack([f[key] for f in traj], axis=0)
+        rv_dict[key] = stacked
+    return rv_dict
+
+
+def get_n_chans() -> int:
+    return NUM_CHANS[auto.load_color_space()]
+
+
+def simplify_stacks(obs_vec: np.ndarray, keep_only_latest: bool) -> np.ndarray:
+    # simple sanity checks to make sure frames are N*(C*H)*W
+    assert obs_vec.ndim == 4, f"obs_vec.shape={obs_vec.shape}, so ndim != 4"
+    if obs_vec.shape[-1] != obs_vec.shape[-2]:
+        logging.warning(
+            f"obs_vec.shape={obs_vec.shape} does not look N(C*F)HW, "
+            "since H!=W")
+    n_chans = get_n_chans()
+    stack_len = obs_vec.shape[1] // n_chans
+    assert stack_len * n_chans == obs_vec.shape[1], \
+        f"obs_vec.shape={obs_vec.shape} should be N(C*F)HW, "\
+        f"but first dim is not divisible by n_chans={n_chans}"
+    new_shape = obs_vec.shape[:1] + (stack_len, n_chans) + obs_vec.shape[2:]
+    destacked = np.reshape(obs_vec, new_shape)
+    # put stack dimension first
+    transposed = np.transpose(destacked, (1, 0, 2, 3, 4))
+    if keep_only_latest:
+        final_obs_vec = transposed[-1]
+    else:
+        final_obs_vec = np.concatenate(transposed, axis=3)
+    # now it's actually N*C*H*W', where W' has absorbed all the stacked frames
+    # from before
+    return final_obs_vec
+
+
+def to_film_strip(images: np.ndarray, border_size: int = 1) -> np.ndarray:
+    """Convert an N*C*H*W array of image frames into a horizontal 'film strip'
+    with a black border of `border_size` separating the frames (as well as a
+    border on the outsides)."""
+    # make a big array to hold all the images
+    n_images, n_chans, height, width = images.shape
+    out_array_size = (n_chans, 2 * border_size + height,
+                      n_images * width + (n_images + 1) * border_size)
+    out_array = np.zeros(out_array_size, dtype=images.dtype)
+    for idx, imag in enumerate(images):
+        h_start = border_size
+        h_stop = h_start + imag.shape[1]
+        w_start = border_size * (idx + 1) + width * idx
+        w_stop = w_start + imag.shape[2]
+        out_array[:, h_start:h_stop, w_start:w_stop] = imag
+    return out_array
+
+
+def save_obs_as_film(obs: np.ndarray, dest: str, keep_only_latest: bool,
+                     border_size: int, frames_per_traj: int) -> None:
+    """Save a list of observations in N*(C*F)*H*W format into a file, after
+    converting to a 'film strip' (appropriate for representing, e.g., a
+    continuous trajectory)."""
+    d = os.path.dirname(dest)
+    if d:
+        os.makedirs(d, exist_ok=True)
+    simple_indices = sample_points(len(obs), frames_per_traj)
+    obs = obs[simple_indices]
+    images = simplify_stacks(obs, keep_only_latest=keep_only_latest)
+    film = to_film_strip(images, border_size=border_size)
+    film_hwc = np.transpose(film, (1, 2, 0))
+    try:
+        pil_image = Image.fromarray(film_hwc)
+    except TypeError:
+        # Minecraft observations are stored as floats in [0, 1], which PIL
+        # cannot handle directly, so rescale to uint8
+        pil_image = Image.fromarray((film_hwc * 255).astype(np.uint8))
+
+    pil_image.save(dest)
+
+
+@render_dataset_ex.main
+def run(n_traj: int, frames_per_traj: int, out_dir: str, dataset_config: dict,
+        keep_only_latest: bool, border_size: int, seed: int) -> None:
+    set_global_seeds(seed)
+    logging.getLogger().setLevel(logging.INFO)
+
+    print('Supplied dataset config:')
+    pprint.pprint(dataset_config)
+
+    # we only support loading one dataset (hence the [dataset_config] thing)
+    (webdataset, ), combined_meta = auto.load_wds_datasets(
+        configs=[dataset_config])
+
+    print("Collected metadata from loaded dataset:")
+    pprint.pprint(combined_meta)
+
+    # now write same trajectories to out_dir
+    os.makedirs(out_dir, exist_ok=True)
+    trajectories = it.islice(trajectory_iter(webdataset), n_traj)
+    for idx, trajectory in enumerate(trajectories):
+        print(f"Hit trajectory {idx}")
+        traj_dict = concat_traj(trajectory)
+        num_str = f'{idx:06d}'
+        for key in ('obs', 'next_obs'):
+            save_obs_as_film(
+                traj_dict[key],
+                os.path.join(out_dir, f'{key}_{num_str}.png'),
+                keep_only_latest=keep_only_latest,
+                border_size=border_size,
+                frames_per_traj=frames_per_traj)
+
+
+if __name__ == '__main__':
+    render_dataset_ex.run_commandline()
\ No newline at end of file
diff --git a/src/il_representations/test_support/configuration.py b/src/il_representations/test_support/configuration.py
index 693bfcc4..5d6acd69 100644
--- a/src/il_representations/test_support/configuration.py
+++ b/src/il_representations/test_support/configuration.py
@@ -4,7 +4,7 @@
 from ray import tune
 
 from il_representations import algos
-
+from il_representations.envs.auto import benchmark_is_available
 CURRENT_DIR = path.dirname(path.abspath(__file__))
 TEST_DATA_DIR = path.abspath(
     path.join(CURRENT_DIR, '..', '..', '..', 'tests', 'data'))
@@ -31,14 +31,24 @@
     {
         'benchmark_name': 'dm_control',
         'task_name': 'reacher-easy',
-    },
-    {
-        'benchmark_name': 'minecraft',
-        'task_name': 'NavigateVectorObf',
-        'minecraft_max_env_steps': 100
     }
 ]
+
+if benchmark_is_available('minecraft'):
+    # Done this way because we need to import configuration elements from RB
+    from realistic_benchmarks import wrappers as rb_wrappers
+    from il_representations.envs.utils import MinecraftPOVWrapper, Testing2500StepLimitWrapper
+    ENV_CFG_TEST_CONFIGS.append(
+        {
+            'benchmark_name': 'minecraft',
+            'task_name': 'NavigateVectorObf',
+            'minecraft_max_env_steps': 100,
+            'minecraft_wrappers': [rb_wrappers.ActionFlatteningWrapper,
+                                   MinecraftPOVWrapper,
+                                   Testing2500StepLimitWrapper]
+        }
+    )
+
 FAST_IL_TRAIN_CONFIG = {
     'bc': {
         'n_batches': 1,
diff --git a/tests/data/processed/demos/minecraft/NavigateVectorObf/demos.tgz b/tests/data/processed/demos/minecraft/NavigateVectorObf/demos.tgz
index f8a298d0..b8d3259e 100644
Binary files a/tests/data/processed/demos/minecraft/NavigateVectorObf/demos.tgz and b/tests/data/processed/demos/minecraft/NavigateVectorObf/demos.tgz differ
diff --git a/tests/test_il_train_test.py b/tests/test_il_train_test.py
index 4acf6fde..5e26c979 100644
--- a/tests/test_il_train_test.py
+++ b/tests/test_il_train_test.py
@@ -36,7 +36,6 @@ def test_il_train_test(env_cfg, algo, il_train_ex, il_test_ex,
     # FIXME(sam): same comment as elsewhere: should have a better way of
     # getting at saved policies.
     log_dir = file_observer.dir
-
     # test
     policy_path = os.path.join(log_dir, final_pol_name)
     il_test_ex.run(