diff --git a/ss2r/algorithms/sbsrl/losses.py b/ss2r/algorithms/sbsrl/losses.py
index 481713c85..1c5dfc491 100644
--- a/ss2r/algorithms/sbsrl/losses.py
+++ b/ss2r/algorithms/sbsrl/losses.py
@@ -55,6 +55,8 @@ def make_losses(
     offline,
     flip_uncertainty_constraint,
     target_entropy: float | None = None,
+    separate_critics: bool = False,
+    optimistic_qr: bool = False,
 ):
     target_entropy = -0.5 * action_size if target_entropy is None else target_entropy
     policy_network = sbsrl_network.policy_network
@@ -216,6 +218,7 @@ def actor_loss(
        safety_budget: float,
        penalizer: Penalizer | None,
        penalizer_params: Any,
+       backup_qc_params: Params | None,
    ) -> jnp.ndarray:
        dist_params = policy_network.apply(
            normalizer_params, policy_params, transitions.observation
@@ -243,20 +246,45 @@ def actor_loss(
        qr = jnp.min(qr_action, axis=-1)
        qr /= reward_scaling
        actor_loss = -qr.mean()
+       if optimistic_qr:
+           actor_loss = -jnp.max(jnp.mean(qr, axis=-1))
        exploration_loss = (alpha * log_prob).mean()
        aux = {}
        if safe or uncertainty_constraint:
            assert qc_network is not None
-           qc_action = jax.vmap(
-               lambda i: qc_network.apply(
-                   normalizer_params,
-                   qc_params,
-                   transitions.observation,
-                   action,
-                   jnp.full((transitions.observation.shape[0],), i, dtype=jnp.int32),
+           if separate_critics:
+               qc_action = jax.vmap(
+                   lambda i, p: qc_network.apply(
+                       normalizer_params,
+                       p,
+                       transitions.observation,
+                       action,
+                       jnp.full(
+                           (transitions.observation.shape[0],), i, dtype=jnp.int32
+                       ),
+                   ),
+                   in_axes=(0, 0),
+               )(idxs, qc_params)
+           else:
+               qc_action = jax.vmap(
+                   lambda i: qc_network.apply(
+                       normalizer_params,
+                       qc_params,
+                       transitions.observation,
+                       action,
+                       jnp.full(
+                           (transitions.observation.shape[0],), i, dtype=jnp.int32
+                       ),
+                   )
+               )(idxs)  # (E, B, n_critics*head_size)
+           if save_sooper_backup:
+               assert backup_qc_network is not None
+               qc_backup = backup_qc_network.apply(
+                   normalizer_params, backup_qc_params, transitions.observation, action
                )
-           )(idxs)  # (E, B, n_critics*head_size)
+               qc_backup = jnp.mean(qc_backup)
+               aux["qc_backup"] = qc_backup
            qc_action = qc_action.reshape(
                ensemble_size, -1, n_critics, int(safe) + int(uncertainty_constraint)
            )  # -> (E, B, n_critics, head_size)
@@ -420,6 +448,7 @@ def policy(obs: jax.Array) -> tuple[jax.Array, jax.Array]:
     return (
         alpha_loss,
         critic_loss_vmap,
+        critic_loss,
         actor_loss,
         compute_model_loss,
         backup_critic_loss,
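Review note on the `optimistic_qr` branch above: the pessimistic `actor_loss = -qr.mean()` is computed and then overwritten, and the new reduction only makes sense if `qr` stacks per-ensemble-member Q estimates along a leading axis. A minimal sketch of the intended objective, assuming `qr` has shape (ensemble, batch); the function name is illustrative, not the repo's API:

import jax.numpy as jnp

def optimistic_actor_loss(qr: jnp.ndarray) -> jnp.ndarray:
    # Average Q over the batch for each ensemble member, then be optimistic:
    # the policy is trained against the most favourable member.
    per_member_value = jnp.mean(qr, axis=-1)  # (ensemble,)
    return -jnp.max(per_member_value)

qr = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # 2 members, batch of 2
assert optimistic_actor_loss(qr) == -3.5  # best member's batch mean
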
diff --git a/ss2r/algorithms/sbsrl/networks.py b/ss2r/algorithms/sbsrl/networks.py
index e6a5b33ff..330b5a74e 100644
--- a/ss2r/algorithms/sbsrl/networks.py
+++ b/ss2r/algorithms/sbsrl/networks.py
@@ -194,6 +194,61 @@ def apply(processor_params, q_params, obs, actions, idx):
     )
 
 
+def make_q_network_separate(
+    obs_size: types.ObservationSize,
+    action_size: int,
+    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    hidden_layer_sizes: Sequence[int] = (256, 256),
+    activation: ActivationFn = linen.relu,
+    n_critics: int = 2,
+    obs_key: str = "state",
+    use_bro: bool = True,
+    n_heads: int = 1,
+    head_size: int = 1,
+    ensemble_size: int = 10,
+    embedding_dim: int = 4,
+) -> networks.FeedForwardNetwork:
+    """Creates a Q network with independent parameters per ensemble member (no index embedding)."""
+
+    class QModule(linen.Module):
+        """Q Module."""
+
+        n_critics: int
+
+        @linen.compact
+        def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray, idx: jnp.ndarray):
+            hidden = jnp.concatenate([obs, actions], axis=-1)
+
+            res = []
+            net = BroNet if use_bro else MLP
+            for _ in range(self.n_critics):
+                q = net(  # type: ignore
+                    layer_sizes=list(hidden_layer_sizes) + [head_size],
+                    activation=activation,
+                    kernel_init=jax.nn.initializers.lecun_uniform(),
+                    num_heads=n_heads,
+                )(hidden)
+                res.append(q)
+            return jnp.concatenate(res, axis=-1)
+
+    q_module = QModule(n_critics=n_critics)
+
+    def apply(processor_params, q_params, obs, actions, idx):
+        obs = preprocess_observations_fn(obs, processor_params)
+        obs = obs if isinstance(obs, jax.Array) else obs[obs_key]
+        idx = jnp.asarray(idx, dtype=jnp.int32)
+        return q_module.apply(q_params, obs, actions, idx)
+
+    obs_size = _get_obs_state_size(obs_size, obs_key)
+    dummy_obs = jnp.zeros((1, obs_size))
+    dummy_action = jnp.zeros((1, action_size))
+    dummy_idx = jnp.zeros((1,), dtype=jnp.int32)
+    return networks.FeedForwardNetwork(
+        init=lambda key: q_module.init(key, dummy_obs, dummy_action, dummy_idx),
+        apply=apply,
+    )
+
+
 def make_sbsrl_networks(
     observation_size: types.ObservationSize,
     action_size: int,
@@ -213,6 +268,7 @@ def make_sbsrl_networks(
     save_sooper_backup: bool = False,
     ensemble_size: int = 10,
     embedding_dim: int = 4,
+    separate_critics: bool = False,
 ) -> SBSRLNetworks:
     parametric_action_distribution = distribution.NormalTanhDistribution(
         event_size=action_size
@@ -240,20 +296,36 @@ def make_sbsrl_networks(
     )
     if safe or uncertainty_constraint:
         n_outputs_qc = int(safe) + int(uncertainty_constraint)
-        qc_network = make_q_network_ensemble(
-            observation_size,
-            action_size,
-            preprocess_observations_fn=preprocess_observations_fn,
-            hidden_layer_sizes=value_hidden_layer_sizes,
-            activation=activation,
-            obs_key=value_obs_key,
-            use_bro=use_bro,
-            n_critics=n_critics,
-            n_heads=n_heads,
-            ensemble_size=ensemble_size,
-            embedding_dim=embedding_dim,
-            head_size=n_outputs_qc,
-        )
+        if separate_critics:
+            qc_network = make_q_network_separate(
+                observation_size,
+                action_size,
+                preprocess_observations_fn=preprocess_observations_fn,
+                hidden_layer_sizes=value_hidden_layer_sizes,
+                activation=activation,
+                obs_key=value_obs_key,
+                use_bro=use_bro,
+                n_critics=n_critics,
+                n_heads=n_heads,
+                ensemble_size=ensemble_size,
+                embedding_dim=embedding_dim,
+                head_size=n_outputs_qc,
+            )
+        else:
+            qc_network = make_q_network_ensemble(
+                observation_size,
+                action_size,
+                preprocess_observations_fn=preprocess_observations_fn,
+                hidden_layer_sizes=value_hidden_layer_sizes,
+                activation=activation,
+                obs_key=value_obs_key,
+                use_bro=use_bro,
+                n_critics=n_critics,
+                n_heads=n_heads,
+                ensemble_size=ensemble_size,
+                embedding_dim=embedding_dim,
+                head_size=n_outputs_qc,
+            )
         old_apply = qc_network.apply
         qc_network.apply = lambda *args, **kwargs: jnn.softplus(
             old_apply(*args, **kwargs)
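The `separate_critics` variant above drops the shared ensemble-index embedding: each member gets its own parameter set (created in train.py by vmapping `init` over split keys), and `apply` is then vmapped over the stacked parameters. A self-contained sketch of that pattern with a toy flax module; `TinyQ` is illustrative, not the repo's BroNet/MLP:

import jax
import jax.numpy as jnp
import flax.linen as linen

class TinyQ(linen.Module):
    @linen.compact
    def __call__(self, obs, act):
        x = jnp.concatenate([obs, act], axis=-1)
        return linen.Dense(1)(linen.relu(linen.Dense(32)(x)))

net = TinyQ()
obs, act = jnp.ones((4, 3)), jnp.ones((4, 2))
keys = jax.random.split(jax.random.PRNGKey(0), 5)         # ensemble of 5
params = jax.vmap(lambda k: net.init(k, obs, act))(keys)  # stacked params
q = jax.vmap(lambda p: net.apply(p, obs, act))(params)
print(q.shape)  # (5, 4, 1): one Q estimate per member

Note that the QModule in `make_q_network_separate` still accepts `idx` but ignores it, presumably to keep the call signature interchangeable with `make_q_network_ensemble`.
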
diff --git a/ss2r/algorithms/sbsrl/on_policy_training_step.py b/ss2r/algorithms/sbsrl/on_policy_training_step.py
index f5f46d7e2..7552ec3a6 100644
--- a/ss2r/algorithms/sbsrl/on_policy_training_step.py
+++ b/ss2r/algorithms/sbsrl/on_policy_training_step.py
@@ -60,6 +60,8 @@ def make_on_policy_training_step(
     ensemble_size,
     sac_batch_size,
     normalize_fn,
+    separate_critics,
+    ensemble_index,
 ) -> TrainingStepFn:
     def split_transitions_ensemble(
         transitions: Transition, ensemble_axis: int = 1
@@ -99,7 +101,7 @@ def _reshape_leaf(x, name):
         return trans_per_ens
 
     def compress_transitions_ensemble(
-        transitions: Transition, ensemble_axis: int = 1
+        transitions: Transition, ensemble_index, ensemble_axis: int = 1
     ) -> Transition:
         def _reduce_leaf(x: Any, name: str):
             if isinstance(x, dict):
@@ -107,6 +109,8 @@ def _reduce_leaf(x: Any, name: str):
             x_arr = jnp.asarray(x)
             if name in ("observation", "action"):
                 return x_arr
+            if ensemble_index != -1:
+                return jnp.take(x_arr, ensemble_index, axis=ensemble_axis)
             return jnp.mean(x_arr, axis=ensemble_axis)
 
         replacements = {
@@ -159,7 +163,7 @@ def sgd_step(
         )
         if save_sooper_backup:
             compressed_transitions = compress_transitions_ensemble(
-                transitions, ensemble_axis=1
+                transitions, ensemble_index, ensemble_axis=1
             )
             (
                 backup_critic_loss,
@@ -167,7 +171,7 @@ def sgd_step(
                 backup_qr_optimizer_state,
             ) = backup_critic_update(
                 training_state.backup_qr_params,
-                training_state.behavior_policy_params,  # TODO: Is it correct to use a common policy for backup and behavior?
+                training_state.behavior_policy_params,
                 training_state.normalizer_params,
                 training_state.backup_target_qr_params,
                 alpha,
@@ -188,24 +192,55 @@ def sgd_step(
             key, key_cost = jax.random.split(key)
             _, *ens_keys_cost = jax.random.split(key_cost, ensemble_size + 1)
             ens_keys_cost = jnp.stack(ens_keys_cost)
-            (
-                cost_loss,
-                behavior_qc_params,
-                behavior_qc_optimizer_state,
-            ) = cost_critic_update(
-                training_state.behavior_qc_params,
-                training_state.behavior_policy_params,
-                training_state.normalizer_params,
-                training_state.behavior_target_qc_params,
-                alpha,
-                trans_per_ens,
-                ens_keys_cost,
-                cost_q_transform,
-                safe,
-                uncertainty_constraint,
-                optimizer_state=training_state.behavior_qc_optimizer_state,
-                params=training_state.behavior_qc_params,
-            )
+
+            if not separate_critics:
+                (
+                    cost_loss,
+                    behavior_qc_params,
+                    behavior_qc_optimizer_state,
+                ) = cost_critic_update(
+                    training_state.behavior_qc_params,
+                    training_state.behavior_policy_params,
+                    training_state.normalizer_params,
+                    training_state.behavior_target_qc_params,
+                    alpha,
+                    trans_per_ens,
+                    ens_keys_cost,
+                    cost_q_transform,
+                    safe,
+                    uncertainty_constraint,
+                    optimizer_state=training_state.behavior_qc_optimizer_state,
+                    params=training_state.behavior_qc_params,
+                )
+            else:
+                per_member_vmap = jax.vmap(
+                    lambda p_i, opt_i, trans_i, key_i, tq_i: cost_critic_update(
+                        p_i,
+                        training_state.behavior_policy_params,
+                        training_state.normalizer_params,
+                        tq_i,
+                        alpha,
+                        trans_i,
+                        key_i,
+                        cost_q_transform,
+                        safe,
+                        uncertainty_constraint,
+                        optimizer_state=opt_i,
+                        params=p_i,
+                    ),
+                    in_axes=(0, 0, 0, 0, 0),
+                )
+                (
+                    cost_loss,
+                    behavior_qc_params,
+                    behavior_qc_optimizer_state,
+                ) = per_member_vmap(
+                    training_state.behavior_qc_params,
+                    training_state.behavior_qc_optimizer_state,
+                    trans_per_ens,
+                    ens_keys_cost,
+                    training_state.behavior_target_qc_params,
+                )
             cost_metrics["behavior_cost_critic_loss"] = cost_loss
         else:
             behavior_qc_params = training_state.behavior_qc_params
@@ -251,6 +286,7 @@ def sgd_step(
             safety_budget,
             penalizer,
             training_state.penalizer_params,
+            training_state.backup_qc_params,
             optimizer_state=training_state.behavior_policy_optimizer_state,
             params=training_state.behavior_policy_params,
         )
@@ -420,9 +456,10 @@ def generate_model_data(
             transitions.reward.shape, disagreement
         )  # (B,ensemble_size)
         disagreement_metrics = {"normalized_disagreement": disagreement.mean()}
-        disagreement_metrics["model_stage_cost"] = transitions.extras[
-            "state_extras"
-        ]["cost"]
+        if safe:
+            disagreement_metrics["model_stage_cost"] = transitions.extras[
+                "state_extras"
+            ]["cost"]
         sac_replay_buffer_state = sac_replay_buffer.insert(
             sac_replay_buffer_state, float16(transitions)
         )
@@ -544,7 +581,8 @@ def training_step(
         training_key = key
     model_buffer_state, transitions = model_replay_buffer.sample(model_buffer_state)
     cost_metrics = {}
-    cost_metrics["real_stage_cost"] = transitions.extras["state_extras"]["cost"]
+    if safe:
+        cost_metrics["real_stage_cost"] = transitions.extras["state_extras"]["cost"]
     # Change the front dimension of transitions so 'update_step' is called
     # grad_updates_per_step times by the scan.
     tmp_transitions = jax.tree_util.tree_map(
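Semantics of the new `ensemble_index` argument above: `-1` keeps the previous behaviour (average model-generated targets over the ensemble axis), while any other value pins the compressed transitions to a single ensemble member. A minimal sketch of the leaf reduction, assuming leaves of shape (batch, ensemble, ...):

import jax.numpy as jnp

def compress_leaf(x: jnp.ndarray, ensemble_index: int, ensemble_axis: int = 1):
    if ensemble_index != -1:
        # Pin to one member's rollouts.
        return jnp.take(x, ensemble_index, axis=ensemble_axis)
    # Default: average over the ensemble.
    return jnp.mean(x, axis=ensemble_axis)

rewards = jnp.arange(6.0).reshape(2, 3)  # batch of 2, ensemble of 3
print(compress_leaf(rewards, -1))        # [1. 4.] (ensemble mean)
print(compress_leaf(rewards, 2))         # [2. 5.] (member 2 only)
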
diff --git a/ss2r/algorithms/sbsrl/train.py b/ss2r/algorithms/sbsrl/train.py
index f725f5479..a4358808e 100644
--- a/ss2r/algorithms/sbsrl/train.py
+++ b/ss2r/algorithms/sbsrl/train.py
@@ -112,6 +112,7 @@ def _init_training_state(
     model_ensemble_size: int,
     embedding_dim: int,
     penalizer_params: Params | None,
+    separate_critics: bool,
 ) -> TrainingState:
     """Inits the training state and replicates it over devices."""
     key_policy, key_qr, key_model = jax.random.split(key, 3)
@@ -126,9 +127,19 @@ def _init_training_state(
     model_params = init_model_ensemble(model_keys)
     model_optimizer_state = model_optimizer.init(model_params)
     if sbsrl_network.qc_network is not None:
-        behavior_qc_params = sbsrl_network.qc_network.init(key_qr)
-        assert qc_optimizer is not None
-        behavior_qc_optimizer_state = qc_optimizer.init(behavior_qc_params)
+        if separate_critics:
+            keys = jax.random.split(key_qr, model_ensemble_size)
+            behavior_qc_params = jax.vmap(lambda k: sbsrl_network.qc_network.init(k))(
+                keys
+            )
+            assert qc_optimizer is not None
+            behavior_qc_optimizer_state = jax.vmap(lambda p: qc_optimizer.init(p))(
+                behavior_qc_params
+            )
+        else:
+            behavior_qc_params = sbsrl_network.qc_network.init(key_qr)
+            assert qc_optimizer is not None
+            behavior_qc_optimizer_state = qc_optimizer.init(behavior_qc_params)
     else:
         behavior_qc_params = None
         behavior_qc_optimizer_state = None
@@ -259,6 +270,10 @@ def train(
     learn_from_scratch: bool = False,
     target_entropy: float | None = None,
     pessimistic_q: bool = False,
+    separate_critics: bool = False,
+    load_data: bool = False,
+    optimistic_qr: bool = False,
+    ensemble_index: int = -1,
 ):
     if min_replay_size >= num_timesteps:
         raise ValueError(
@@ -393,6 +408,7 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
         model_ensemble_size=model_ensemble_size,
         embedding_dim=embedding_dim,
         penalizer_params=penalizer_params,
+        separate_critics=separate_critics,
     )
     del global_key
     local_key, model_rb_key, actor_critic_rb_key, env_key, eval_key = jax.random.split(
@@ -491,6 +507,8 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
            else None,
        )
    else:
+       if load_data:
+           model_buffer_state = replay_buffers.ReplayBufferState(**params[-1])
        policy_optimizer_state = restore_state(
            params[6][1]["inner_state"]
            if isinstance(params[6][1], dict)
@@ -551,6 +569,7 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
     (
         alpha_loss,
         critic_loss,
+        critic_loss_separate,
         actor_loss,
         model_loss,
         backup_critic_loss,
@@ -576,6 +595,8 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
         offline=offline,
         flip_uncertainty_constraint=flip_uncertainty_constraint,
         target_entropy=target_entropy,
+        separate_critics=separate_critics,
+        optimistic_qr=optimistic_qr,
     )
     alpha_update = (
         gradients.gradient_update_fn(  # pytype: disable=wrong-arg-types  # jax-ndarray
@@ -588,9 +609,14 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
         )
     )
     if safe or uncertainty_constraint:
-        cost_critic_update = gradients.gradient_update_fn(  # pytype: disable=wrong-arg-types  # jax-ndarray
-            critic_loss, qc_optimizer, pmap_axis_name=None
-        )
+        if separate_critics:
+            cost_critic_update = gradients.gradient_update_fn(
+                critic_loss_separate, qc_optimizer, pmap_axis_name=None
+            )
+        else:
+            cost_critic_update = gradients.gradient_update_fn(  # pytype: disable=wrong-arg-types  # jax-ndarray
+                critic_loss, qc_optimizer, pmap_axis_name=None
+            )
     else:
         cost_critic_update = None
     model_update = (
@@ -673,6 +699,8 @@ def normalize_leaf(data: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
         model_ensemble_size,
         sac_batch_size,
         disagreement_normalize_fn,
+        separate_critics,
+        ensemble_index,
     )
 
     def prefill_replay_buffer(
@@ -838,7 +866,7 @@ def training_epoch_with_timing(
     # Create and initialize the replay buffer.
     t = time.time()
     prefill_key, local_key = jax.random.split(local_key)
-    if not offline:
+    if not (offline or load_data):
         training_state, env_state, model_buffer_state, _ = prefill_replay_buffer(
             training_state, env_state, model_buffer_state, prefill_key
         )
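The `separate_critics` initialisation above stacks one parameter set and one optimizer state per ensemble member by vmapping `init`; the per-member cost-critic update in on_policy_training_step.py then vmaps over both stacks. A self-contained sketch of that pattern with optax; shapes and names are illustrative:

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(3e-4)
stacked_params = {"w": jnp.ones((5, 32, 1))}  # 5 members
opt_states = jax.vmap(optimizer.init)(stacked_params)

def member_update(params, opt_state, grads):
    # One optimizer step for a single ensemble member.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

grads = jax.tree_util.tree_map(jnp.ones_like, stacked_params)
new_params, new_states = jax.vmap(member_update)(stacked_params, opt_states, grads)
print(new_params["w"].shape)  # (5, 32, 1): members stay independent
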
diff --git a/ss2r/configs/agent/sbsrl.yaml b/ss2r/configs/agent/sbsrl.yaml
index 1387583ee..e6915ba65 100644
--- a/ss2r/configs/agent/sbsrl.yaml
+++ b/ss2r/configs/agent/sbsrl.yaml
@@ -57,6 +57,11 @@ training_step_fn: on_policy
 uncertainty_constraint: false
 uncertainty_epsilon: 0.0
 use_mean_critic: false
+use_mean_dynamics: false
+ensemble_index: -1
 use_max_critic: false
 load_disagreement_normalizer: false
-pessimistic_cost: false
\ No newline at end of file
+pessimistic_cost: false
+separate_critics: false
+load_data: false
+optimistic_qr: false
\ No newline at end of file
diff --git a/ss2r/configs/experiment/cartpole_swingup_sbsrl_ep2.yaml b/ss2r/configs/experiment/cartpole_swingup_sbsrl_ep2.yaml
new file mode 100644
index 000000000..96ca8f96c
--- /dev/null
+++ b/ss2r/configs/experiment/cartpole_swingup_sbsrl_ep2.yaml
@@ -0,0 +1,45 @@
+# @package _global_
+defaults:
+  - override /environment: cartpole_swingup
+  - override /agent: sbsrl
+  - override /agent/data_collection: episodic
+  - override /agent/penalizer: multiaug_lagrangian
+  - _self_
+
+training:
+  num_timesteps: 2500000
+  action_repeat: 4
+  safe: true
+  train_domain_randomization: false
+  eval_domain_randomization: false
+  safety_budget: 100
+  num_envs: 20
+  num_evals: 15
+  num_eval_episodes: 10
+  wandb_id: tjabwpzf #,4ne87e1s,651uyrm2,ytjenkjh,10yjpvr3
+
+agent:
+  policy_hidden_layer_sizes: [256, 256, 256]
+  value_hidden_layer_sizes: [512, 512]
+  activation: swish
+  batch_size: 256
+  min_replay_size: 8192 #1000
+  max_replay_size: 1548576
+  critic_grad_updates_per_step: 2000
+  model_grad_updates_per_step: 80000
+  num_model_rollouts: 100000
+  learning_rate: 3e-4
+  critic_learning_rate: 3e-4
+  model_learning_rate: 3e-4
+  uncertainty_constraint: true
+  uncertainty_epsilon: 0.0
+  use_mean_critic: false
+  separate_critics: true
+  model_to_real_data_ratio: 1
+  optimistic_qr: true
+  penalizer:
+    lagrange_multiplier: 2.5
+    penalty_multiplier: 1e-3
+    penalty_multiplier_factor: 1e-4
+    lagrange_multiplier_sigma: 0
+
diff --git a/ss2r/configs/experiment/cartpole_swingup_simple_sbsrl_offline.yaml b/ss2r/configs/experiment/cartpole_swingup_simple_sbsrl_offline.yaml
index 1867b31a3..23a4cad5e 100644
--- a/ss2r/configs/experiment/cartpole_swingup_simple_sbsrl_offline.yaml
+++ b/ss2r/configs/experiment/cartpole_swingup_simple_sbsrl_offline.yaml
@@ -35,10 +35,11 @@ agent:
   uncertainty_epsilon: 0
   cost_scaling: 1
   reward_scaling: 1
-  model_to_real_data_ratio: 0.5
-  reward_pessimism: 100
+  model_to_real_data_ratio: 1
+  reward_pessimism: 0
   use_mean_critic: false
+  separate_critics: true
   penalizer:
-    lagrange_multiplier: 0.1
+    lagrange_multiplier: 0.5
     penalty_multiplier: 0.001
     penalty_multiplier_factor: 1e-4
diff --git a/ss2r/configs/experiment/cartpole_swingup_sparse_ep.yaml b/ss2r/configs/experiment/cartpole_swingup_sparse_ep.yaml
new file mode 100644
index 000000000..faae95c78
--- /dev/null
+++ b/ss2r/configs/experiment/cartpole_swingup_sparse_ep.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+defaults:
+  - override /environment: cartpole_swingup
+  - override /agent: sbsrl
+  - override /agent/data_collection: episodic
+  - override /agent/penalizer: multiaug_lagrangian
+  - _self_
+
+environment:
+  task_name: HardCartpoleSwingupSparse
+  task_params:
+    action_cost_scale: 0.2
+
+training:
+  num_timesteps: 2000000
+  safe: false
+  num_envs: 20
+  train_domain_randomization: false
+  eval_domain_randomization: false
+  action_repeat: 1
+
+agent:
+  uncertainty_constraint: true
+  uncertainty_epsilon: 75
+  model_to_real_data_ratio: 0.25
+  min_replay_size: 1000
+  sac_batch_size: 256
+  critic_grad_updates_per_step: 4000
+  model_grad_updates_per_step: 500
+  num_model_rollouts: 100000
+  num_critic_updates_per_actor_update: 1
+  optimistic_qr: true
+  normalize_disagreement: true
+  penalizer:
+    lagrange_multiplier_sigma: 0.1
+    penalty_multiplier: 1e-3
+    penalty_multiplier_factor: 1e-4
diff --git a/ss2r/configs/experiment/go_to_goal_easy_sbsrl_ep.yaml b/ss2r/configs/experiment/go_to_goal_easy_sbsrl_ep.yaml
new file mode 100644
index 000000000..a972a1690
--- /dev/null
+++ b/ss2r/configs/experiment/go_to_goal_easy_sbsrl_ep.yaml
@@ -0,0 +1,49 @@
+# @package _global_
+defaults:
+  - override /environment: go_to_goal_easy
+  - override /agent: sbsrl
+  - override /agent/data_collection: episodic
+  - override /agent/penalizer: multiaug_lagrangian
+  - _self_
+
+training:
+  num_timesteps: 200000
+  train_domain_randomization: false
+  eval_domain_randomization: false
+  safe: true
+  safety_budget: 25
+  action_repeat: 4
+  num_envs: 1
+  num_evals: 20
+  wandb_id: be5glx0u #,naysr4ow,p14k5175,0mzh8epy,km2chbf8
+
+agent:
+  policy_hidden_layer_sizes: [256, 256, 256]
+  value_hidden_layer_sizes: [512, 512]
+  activation: swish
+  batch_size: 256
+  max_replay_size: 4194304
+  critic_grad_updates_per_step: 2000
+  model_grad_updates_per_step: 130000
+  num_critic_updates_per_actor_update: 3
+  num_model_rollouts: 100000
+  learning_rate: 3e-6
+  critic_learning_rate: 1e-7
+  model_learning_rate: 1e-4
+
+  safety_discounting: 0.999
+  normalize_budget: false
+  reward_scaling: 1
+  cost_scaling: 1
+  uncertainty_constraint: true
+  uncertainty_epsilon: 0
+  model_to_real_data_ratio: 1
+  use_mean_critic: false
+  optimistic_qr: false
+  separate_critics: true
+  penalizer:
+    lagrange_multiplier: 0.01
+    penalty_multiplier: 0.001
+    penalty_multiplier_factor: 1e-4
+    lagrange_multiplier_sigma: 0
+
diff --git a/ss2r/configs/experiment/go_to_goal_easy_sbsrl_offline.yaml b/ss2r/configs/experiment/go_to_goal_easy_sbsrl_offline.yaml
index 33fdfd6b2..38469ef87 100644
--- a/ss2r/configs/experiment/go_to_goal_easy_sbsrl_offline.yaml
+++ b/ss2r/configs/experiment/go_to_goal_easy_sbsrl_offline.yaml
@@ -34,9 +34,10 @@ agent:
   cost_scaling: 1
   uncertainty_constraint: true
   uncertainty_epsilon: 0
-  model_to_real_data_ratio: 0.5
+  model_to_real_data_ratio: 1
   reward_pessimism: 0.
   use_mean_critic: false
+  separate_critics: true
   penalizer:
     lagrange_multiplier: 0.1
     penalty_multiplier: 0.001
diff --git a/ss2r/configs/experiment/humanoid_walk_sbsrl_lagrangian.yaml b/ss2r/configs/experiment/humanoid_walk_sbsrl_lagrangian.yaml
index a820a8e9b..5e116cc61 100644
--- a/ss2r/configs/experiment/humanoid_walk_sbsrl_lagrangian.yaml
+++ b/ss2r/configs/experiment/humanoid_walk_sbsrl_lagrangian.yaml
@@ -2,8 +2,7 @@
 defaults:
   - override /environment: humanoid_walk
   - override /agent: sbsrl
-  - override /agent/data_collection: episodic #TODO: check
-  - override /agent/cost_robustness: pessimistic_cost_update
+  - override /agent/data_collection: episodic
   - override /agent/penalizer: multiaug_lagrangian
   - _self_
 
@@ -16,7 +15,7 @@ training:
   safety_budget: 100
   num_envs: 1
   num_evals: 20
-  wandb_id: z3518n9m
+  wandb_id: 13rni74f #,jdm2hwk3,vv3cfif1,pecdgnor,35hhiyfg
 
 agent:
   policy_hidden_layer_sizes: [256, 256, 256]
@@ -41,10 +40,11 @@ agent:
   use_mean_critic: false
   reward_pessimism: 0
   cost_pessimism: 0
-  cost_robustness: null
   safety_discounting: 0.999
   normalize_budget: false
+  separate_critics: true
+  optimistic_qr: true
   penalizer:
-    lagrange_multiplier: 5
+    lagrange_multiplier: 0.5 # sweep range 0.01-2.5
     penalty_multiplier: 0.001
     penalty_multiplier_factor: 1e-4
diff --git a/ss2r/configs/experiment/humanoid_walk_sbsrl_offline.yaml b/ss2r/configs/experiment/humanoid_walk_sbsrl_offline.yaml
index b24cba2bf..3130d60fd 100644
--- a/ss2r/configs/experiment/humanoid_walk_sbsrl_offline.yaml
+++ b/ss2r/configs/experiment/humanoid_walk_sbsrl_offline.yaml
@@ -32,12 +32,13 @@ agent:
   reward_scaling: 1
   cost_scaling: 1
   offline: true
-  reward_pessimism: 150
+  reward_pessimism: 0
+  separate_critics: true
   uncertainty_constraint: true
   uncertainty_epsilon: 0
-  model_to_real_data_ratio: 0.5
+  model_to_real_data_ratio: 0.75
   use_mean_critic: false
   penalizer:
-    lagrange_multiplier: 0.1
+    lagrange_multiplier: 5
     penalty_multiplier: 0.001
     penalty_multiplier_factor: 1e-4
diff --git a/ss2r/configs/experiment/rccar_sbsrl_ep.yaml b/ss2r/configs/experiment/rccar_sbsrl_ep.yaml
new file mode 100644
index 000000000..54455cc0b
--- /dev/null
+++ b/ss2r/configs/experiment/rccar_sbsrl_ep.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+defaults:
+  - override /environment: rccar_real
+  - override /agent: sbsrl
+
+  - override /agent/data_collection: episodic
+  - override /agent/penalizer: multiaug_lagrangian
+  - _self_
+
+environment:
+  action_delay: 1
+  observation_delay: 0
+  sliding_window: 5
+  dt: 0.03333333
+  sample_init_pose: true
+
+training:
+  num_envs: 10
+  num_timesteps: 200000
+  episode_length: 250
+  safe: true
+  train_domain_randomization: false
+  eval_domain_randomization: false
+  safety_budget: 5.0
+  wandb_id: 2aim5ebr #,4fumttnh,zi3wgkie,t4fl23qe,4a86mb7h
+
+agent:
+  batch_size: 256
+  min_replay_size: 8196
+  max_replay_size: 1048576
+  policy_hidden_layer_sizes: [64, 64]
+  critic_grad_updates_per_step: 4000
+  model_grad_updates_per_step: 1000
+  num_model_rollouts: 100000
+  learning_rate: 1e-4
+  critic_learning_rate: 1e-4
+  model_learning_rate: 3e-4
+
+
+  normalize_budget: true
+  use_mean_critic: false
+  uncertainty_constraint: true
+  flip_uncertainty_constraint: false
+  uncertainty_epsilon: 0
+  cost_scaling: 1
+  reward_scaling: 1
+  model_to_real_data_ratio: 0.5
+  reward_pessimism: 0
+  cost_pessimism: 0
+  normalize_disagreement: true
+  pessimistic_cost: true
+  separate_critics: true
+  optimistic_qr: true
+
+  penalizer:
+    lagrange_multiplier: 0.01
+    penalty_multiplier: 1e-3
+    penalty_multiplier_factor: 1e-4
diff --git a/ss2r/configs/experiment/rccar_sbsrl_offline_short.yaml b/ss2r/configs/experiment/rccar_sbsrl_offline_short.yaml
index ae9c49052..32b65f49c 100644
--- a/ss2r/configs/experiment/rccar_sbsrl_offline_short.yaml
+++ b/ss2r/configs/experiment/rccar_sbsrl_offline_short.yaml
@@ -14,13 +14,13 @@ environment:
 
 training:
   num_envs: 1
-  num_timesteps: 25000
+  num_timesteps: 40000
   episode_length: 250
   safe: true
   train_domain_randomization: false
   eval_domain_randomization: false
   safety_budget: 5.0
-  wandb_id: fh46s5xs
+  wandb_id: 7168csru
 
 agent:
   batch_size: 256
@@ -43,9 +43,10 @@ agent:
   cost_scaling: 1
   reward_scaling: 1
   model_to_real_data_ratio: 0.5
-  reward_pessimism: 10
+  reward_pessimism: 0
   cost_pessimism: 0.
   normalize_disagreement: true
+  separate_critics: true
   penalizer:
     lagrange_multiplier: 0.1
     penalty_multiplier: 0.001