Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions ss2r/algorithms/sbsrl/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def make_losses(
offline,
flip_uncertainty_constraint,
target_entropy: float | None = None,
separate_critics: bool = False,
optimistic_qr: bool = False,
):
target_entropy = -0.5 * action_size if target_entropy is None else target_entropy
policy_network = sbsrl_network.policy_network
Expand Down Expand Up @@ -216,6 +218,7 @@ def actor_loss(
safety_budget: float,
penalizer: Penalizer | None,
penalizer_params: Any,
backup_qc_params: Params | None,
) -> jnp.ndarray:
dist_params = policy_network.apply(
normalizer_params, policy_params, transitions.observation
Expand Down Expand Up @@ -243,20 +246,45 @@ def actor_loss(
qr = jnp.min(qr_action, axis=-1)
qr /= reward_scaling
actor_loss = -qr.mean()
if optimistic_qr:
actor_loss = -jnp.max(jnp.mean(qr, axis=-1))
exploration_loss = (alpha * log_prob).mean()

aux = {}
if safe or uncertainty_constraint:
assert qc_network is not None
qc_action = jax.vmap(
lambda i: qc_network.apply(
normalizer_params,
qc_params,
transitions.observation,
action,
jnp.full((transitions.observation.shape[0],), i, dtype=jnp.int32),
if separate_critics:
qc_action = jax.vmap(
lambda i, p: qc_network.apply(
normalizer_params,
p,
transitions.observation,
action,
jnp.full(
(transitions.observation.shape[0],), i, dtype=jnp.int32
),
),
in_axes=(0, 0),
)(idxs, qc_params)
else:
qc_action = jax.vmap(
lambda i: qc_network.apply(
normalizer_params,
qc_params,
transitions.observation,
action,
jnp.full(
(transitions.observation.shape[0],), i, dtype=jnp.int32
),
)
)(idxs) # (E, B, n_critics*head_size)
if save_sooper_backup:
assert backup_qc_network is not None
qc_backup = backup_qc_network.apply(
normalizer_params, backup_qc_params, transitions.observation, action
)
)(idxs) # (E, B, n_critics*head_size)
qc_backup = jnp.mean(qc_backup)
aux["qc_backup"] = qc_backup
qc_action = qc_action.reshape(
ensemble_size, -1, n_critics, int(safe) + int(uncertainty_constraint)
) # -> (E, B, n_critics, head_size)
Expand Down Expand Up @@ -420,6 +448,7 @@ def policy(obs: jax.Array) -> tuple[jax.Array, jax.Array]:
return (
alpha_loss,
critic_loss_vmap,
critic_loss,
actor_loss,
compute_model_loss,
backup_critic_loss,
Expand Down
100 changes: 86 additions & 14 deletions ss2r/algorithms/sbsrl/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,61 @@ def apply(processor_params, q_params, obs, actions, idx):
)


def make_q_network_separate(
    obs_size: types.ObservationSize,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.relu,
    n_critics: int = 2,
    obs_key: str = "state",
    use_bro: bool = True,
    n_heads: int = 1,
    head_size: int = 1,
    ensemble_size: int = 10,
    embedding_dim: int = 4,
) -> networks.FeedForwardNetwork:
    """Creates a value network."""

    class QModule(linen.Module):
        """Q Module."""

        n_critics: int

        @linen.compact
        def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray, idx: jnp.ndarray):
            # Note: `idx` is accepted but unused in this variant; the signature
            # mirrors the ensemble Q network so call sites stay interchangeable.
            inputs = jnp.concatenate([obs, actions], axis=-1)
            backbone = BroNet if use_bro else MLP
            # One independent head per critic; concatenated along the last axis.
            heads = [
                backbone(  # type: ignore
                    layer_sizes=[*hidden_layer_sizes, head_size],
                    activation=activation,
                    kernel_init=jax.nn.initializers.lecun_uniform(),
                    num_heads=n_heads,
                )(inputs)
                for _ in range(self.n_critics)
            ]
            return jnp.concatenate(heads, axis=-1)

    value_module = QModule(n_critics=n_critics)

    def apply(processor_params, q_params, obs, actions, idx):
        normalized = preprocess_observations_fn(obs, processor_params)
        if not isinstance(normalized, jax.Array):
            normalized = normalized[obs_key]
        member_idx = jnp.asarray(idx, dtype=jnp.int32)
        return value_module.apply(q_params, normalized, actions, member_idx)

    # Build dummy inputs for parameter initialization.
    state_size = _get_obs_state_size(obs_size, obs_key)
    zero_obs = jnp.zeros((1, state_size))
    zero_action = jnp.zeros((1, action_size))
    zero_idx = jnp.zeros((1,), dtype=jnp.int32)
    return networks.FeedForwardNetwork(
        init=lambda key: value_module.init(key, zero_obs, zero_action, zero_idx),
        apply=apply,
    )


def make_sbsrl_networks(
observation_size: types.ObservationSize,
action_size: int,
Expand All @@ -213,6 +268,7 @@ def make_sbsrl_networks(
save_sooper_backup: bool = False,
ensemble_size: int = 10,
embedding_dim: int = 4,
separate_critics: bool = False,
) -> SBSRLNetworks:
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size
Expand Down Expand Up @@ -240,20 +296,36 @@ def make_sbsrl_networks(
)
if safe or uncertainty_constraint:
n_outputs_qc = int(safe) + int(uncertainty_constraint)
qc_network = make_q_network_ensemble(
observation_size,
action_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation,
obs_key=value_obs_key,
use_bro=use_bro,
n_critics=n_critics,
n_heads=n_heads,
ensemble_size=ensemble_size,
embedding_dim=embedding_dim,
head_size=n_outputs_qc,
)
if separate_critics:
qc_network = make_q_network_separate(
observation_size,
action_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation,
obs_key=value_obs_key,
use_bro=use_bro,
n_critics=n_critics,
n_heads=n_heads,
ensemble_size=ensemble_size,
embedding_dim=embedding_dim,
head_size=n_outputs_qc,
)
else:
qc_network = make_q_network_ensemble(
observation_size,
action_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation,
obs_key=value_obs_key,
use_bro=use_bro,
n_critics=n_critics,
n_heads=n_heads,
ensemble_size=ensemble_size,
embedding_dim=embedding_dim,
head_size=n_outputs_qc,
)
old_apply = qc_network.apply
qc_network.apply = lambda *args, **kwargs: jnn.softplus(
old_apply(*args, **kwargs)
Expand Down
88 changes: 63 additions & 25 deletions ss2r/algorithms/sbsrl/on_policy_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def make_on_policy_training_step(
ensemble_size,
sac_batch_size,
normalize_fn,
separate_critics,
ensemble_index,
) -> TrainingStepFn:
def split_transitions_ensemble(
transitions: Transition, ensemble_axis: int = 1
Expand Down Expand Up @@ -99,14 +101,16 @@ def _reshape_leaf(x, name):
return trans_per_ens

def compress_transitions_ensemble(
transitions: Transition, ensemble_axis: int = 1
transitions: Transition, ensemble_index, ensemble_axis: int = 1
) -> Transition:
def _reduce_leaf(x: Any, name: str):
if isinstance(x, dict):
return {k: _reduce_leaf(v, k) for k, v in x.items()}
x_arr = jnp.asarray(x)
if name in ("observation", "action"):
return x_arr
if ensemble_index != -1:
return jnp.take(x_arr, ensemble_index, axis=ensemble_axis)
return jnp.mean(x_arr, axis=ensemble_axis)

replacements = {
Expand Down Expand Up @@ -159,15 +163,15 @@ def sgd_step(
)
if save_sooper_backup:
compressed_transitions = compress_transitions_ensemble(
transitions, ensemble_axis=1
transitions, ensemble_index, ensemble_axis=1
)
(
backup_critic_loss,
backup_qr_params,
backup_qr_optimizer_state,
) = backup_critic_update(
training_state.backup_qr_params,
training_state.behavior_policy_params, # TODO: Is it correct to use a common policy for backup and behavior?
training_state.behavior_policy_params,
training_state.normalizer_params,
training_state.backup_target_qr_params,
alpha,
Expand All @@ -188,24 +192,55 @@ def sgd_step(
key, key_cost = jax.random.split(key)
_, *ens_keys_cost = jax.random.split(key_cost, ensemble_size + 1)
ens_keys_cost = jnp.stack(ens_keys_cost)
(
cost_loss,
behavior_qc_params,
behavior_qc_optimizer_state,
) = cost_critic_update(
training_state.behavior_qc_params,
training_state.behavior_policy_params,
training_state.normalizer_params,
training_state.behavior_target_qc_params,
alpha,
trans_per_ens,
ens_keys_cost,
cost_q_transform,
safe,
uncertainty_constraint,
optimizer_state=training_state.behavior_qc_optimizer_state,
params=training_state.behavior_qc_params,
)

if not separate_critics:
(
cost_loss,
behavior_qc_params,
behavior_qc_optimizer_state,
) = cost_critic_update(
training_state.behavior_qc_params,
training_state.behavior_policy_params,
training_state.normalizer_params,
training_state.behavior_target_qc_params,
alpha,
trans_per_ens,
ens_keys_cost,
cost_q_transform,
safe,
uncertainty_constraint,
optimizer_state=training_state.behavior_qc_optimizer_state,
params=training_state.behavior_qc_params,
)
else:
per_member_vmap = jax.vmap(
lambda p_i, opt_i, trans_i, key_i, tq_i: cost_critic_update(
p_i,
training_state.behavior_policy_params,
training_state.normalizer_params,
tq_i,
alpha,
trans_i,
key_i,
cost_q_transform,
safe,
uncertainty_constraint,
optimizer_state=opt_i,
params=p_i,
),
in_axes=(0, 0, 0, 0, 0),
)
(
cost_loss,
behavior_qc_params,
behavior_qc_optimizer_state,
) = per_member_vmap(
training_state.behavior_qc_params,
training_state.behavior_qc_optimizer_state,
trans_per_ens,
ens_keys_cost,
training_state.behavior_target_qc_params,
)
cost_metrics["behavior_cost_critic_loss"] = cost_loss
else:
behavior_qc_params = training_state.behavior_qc_params
Expand Down Expand Up @@ -251,6 +286,7 @@ def sgd_step(
safety_budget,
penalizer,
training_state.penalizer_params,
training_state.backup_qc_params,
optimizer_state=training_state.behavior_policy_optimizer_state,
params=training_state.behavior_policy_params,
)
Expand Down Expand Up @@ -420,9 +456,10 @@ def generate_model_data(
transitions.reward.shape, disagreement
) # (B,ensemble_size)
disagreement_metrics = {"normalized_disagreement": disagreement.mean()}
disagreement_metrics["model_stage_cost"] = transitions.extras[
"state_extras"
]["cost"]
if safe:
disagreement_metrics["model_stage_cost"] = transitions.extras[
"state_extras"
]["cost"]
sac_replay_buffer_state = sac_replay_buffer.insert(
sac_replay_buffer_state, float16(transitions)
)
Expand Down Expand Up @@ -544,7 +581,8 @@ def training_step(
training_key = key
model_buffer_state, transitions = model_replay_buffer.sample(model_buffer_state)
cost_metrics = {}
cost_metrics["real_stage_cost"] = transitions.extras["state_extras"]["cost"]
if safe:
cost_metrics["real_stage_cost"] = transitions.extras["state_extras"]["cost"]
# Change the front dimension of transitions so 'update_step' is called
# grad_updates_per_step times by the scan.
tmp_transitions = jax.tree_util.tree_map(
Expand Down
Loading