diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py
new file mode 100644
index 000000000..368c1c1a4
--- /dev/null
+++ b/franka_experiments/ppo_vision_flax_to_onnx.py
@@ -0,0 +1,291 @@
+import argparse
+from pathlib import Path
+from typing import Mapping
+
+import jax
+import numpy as np
+from brax.training.acme import running_statistics
+from brax.training.agents.ppo import checkpoint as ppo_checkpoint
+from brax.training.agents.ppo import networks as ppo_networks
+from brax.training.agents.ppo import networks_vision as ppo_networks_vision
+from flax import linen
+
+from ss2r.algorithms.ppo import franka_ppo_to_onnx
+
+try:
+    import onnxruntime as ort
+except ImportError:
+    ort = None
+
+
+def _resolve_checkpoint_dir(path_str: str) -> Path:
+    path = Path(path_str).expanduser().resolve()
+    if not path.exists():
+        raise FileNotFoundError(f"Checkpoint path does not exist: {path}")
+
+    config_name = "ppo_network_config.json"
+    if (path / config_name).exists():
+        return path
+
+    step_dirs = [p for p in path.iterdir() if p.is_dir() and (p / config_name).exists()]
+    if not step_dirs:
+        raise FileNotFoundError(
+            f"No checkpoint step with {config_name} found under: {path}"
+        )
+
+    numeric = [p for p in step_dirs if p.name.isdigit()]
+    if numeric:
+        return max(numeric, key=lambda p: int(p.name))
+    return sorted(step_dirs)[-1]
+
+
+def _shape_tuple(raw: object) -> tuple[int, ...]:
+    if isinstance(raw, tuple):
+        return tuple(int(x) for x in raw)
+    if isinstance(raw, list):
+        return tuple(int(x) for x in raw)
+    if isinstance(raw, Mapping) and "shape" in raw:
+        return _shape_tuple(raw["shape"])
+    raise TypeError(f"Unsupported shape payload: {raw}")
+
+
+def _extract_policy_params(params) -> Mapping[str, object]:
+    policy_params = params[1] if isinstance(params, (tuple, list)) else params
+    if isinstance(policy_params, Mapping) and "params" in policy_params:
+        policy_params = policy_params["params"]
+    if not isinstance(policy_params, Mapping):
+        raise TypeError("Could not extract policy parameter mapping.")
+    return policy_params
+
+
+def _extract_pixel_obs_shapes(
+    observation_size_cfg: Mapping[str, object],
+    policy_params: Mapping[str, object],
+) -> tuple[tuple[str, ...], dict[str, tuple[int, ...]]]:
+    pixel_obs_keys = tuple(k for k in observation_size_cfg if k.startswith("pixels/"))
+    if not pixel_obs_keys:
+        raise ValueError(
+            "No pixel observation keys found in checkpoint observation_size."
+        )
+
+    obs_shapes: dict[str, tuple[int, ...]] = {}
+    for i, key in enumerate(pixel_obs_keys):
+        try:
+            shape = _shape_tuple(observation_size_cfg[key])
+        except Exception:
+            shape = ()
+
+        if len(shape) == 3:
+            obs_shapes[key] = shape
+            continue
+
+        channels = int(
+            np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2]  # type: ignore[index]
+        )
+        obs_shapes[key] = (64, 64, channels)
+
+    return pixel_obs_keys, obs_shapes
+
+
+def _extract_state_shape(
+    observation_size_cfg: Mapping[str, object],
+    params,
+) -> tuple[int, ...]:
+    if isinstance(params, (tuple, list)) and params:
+        running_stats = params[0]
+        mean = getattr(running_stats, "mean", None)
+        if isinstance(mean, Mapping) and "state" in mean:
+            shape = tuple(int(d) for d in np.asarray(mean["state"]).shape)  # type: ignore
+            if shape:
+                return shape
+
+    if "state" in observation_size_cfg:
+        try:
+            shape = _shape_tuple(observation_size_cfg["state"])
+            if shape:
+                return shape
+        except Exception:
+            pass
+
+    return (1,)
+
+
+def _make_ort_session_options():
+    opts = ort.SessionOptions()
+    opts.intra_op_num_threads = 1
+    opts.inter_op_num_threads = 1
+    return opts
+
+
+def _onnx_input_shapes(session: "ort.InferenceSession") -> dict[str, tuple[int, ...]]:
+    shapes: dict[str, tuple[int, ...]] = {}
+    for inp in session.get_inputs():
+        raw_shape = list(inp.shape)[1:]
+        if any((d is None or isinstance(d, str)) for d in raw_shape):
+            raise ValueError(
+                f"Cannot infer static shape for ONNX input {inp.name}: {inp.shape}."
+            )
+        shapes[inp.name] = tuple(int(d) for d in raw_shape)
+    return shapes
+
+
+def _build_random_obs(
+    rng: np.random.Generator,
+    obs_shapes: Mapping[str, tuple[int, ...]],
+) -> dict[str, np.ndarray]:
+    return {
+        k: rng.standard_normal((1, *shape), dtype=np.float32)
+        for k, shape in obs_shapes.items()
+    }
+
+
+def convert_checkpoint_to_onnx(
+    checkpoint_path: str,
+    output_path: str | None = None,
+    num_tests: int = 0,
+    atol: float = 1e-4,
+) -> Path:
+    ckpt_dir = _resolve_checkpoint_dir(checkpoint_path)
+    print(f"Using PPO checkpoint: {ckpt_dir}")
+
+    params = ppo_checkpoint.load(ckpt_dir)
+    config = ppo_checkpoint.load_config(ckpt_dir)
+    config_dict = config.to_dict()
+
+    observation_size_cfg = config_dict.get("observation_size")
+    if not isinstance(observation_size_cfg, dict):
+        raise TypeError("Checkpoint observation_size must be a mapping for vision PPO.")
+
+    policy_params = _extract_policy_params(params)
+    pixel_obs_keys, network_obs_shapes = _extract_pixel_obs_shapes(
+        observation_size_cfg, policy_params
+    )
+
+    network_factory_kwargs = config_dict.get("network_factory_kwargs", {})
+    if not isinstance(network_factory_kwargs, dict):
+        network_factory_kwargs = {}
+    network_factory_kwargs = dict(network_factory_kwargs)
+
+    activation_name = str(network_factory_kwargs.pop("activation_name", ""))
+    activation = getattr(linen, activation_name, None) if activation_name else None
+    if activation_name and activation is None:
+        raise ValueError(
+            f"Unsupported activation_name in checkpoint: {activation_name!r}"
+        )
+
+    # Force pixel-only policy reconstruction.
+    network_factory_kwargs["policy_obs_key"] = ""
+    network_factory_kwargs["value_obs_key"] = ""
+
+    normalize = (
+        running_statistics.normalize
+        if bool(config_dict.get("normalize_observations", False))
+        else (lambda x, y: x)
+    )
+
+    ppo_network_kwargs = dict(network_factory_kwargs)
+    if activation is not None:
+        ppo_network_kwargs["activation"] = activation
+
+    ppo_network = ppo_networks_vision.make_ppo_networks_vision(
+        observation_size=network_obs_shapes,
+        action_size=int(config_dict["action_size"]),
+        preprocess_observations_fn=normalize,
+        **ppo_network_kwargs,
+    )
+    make_inference_fn = ppo_networks.make_inference_fn(ppo_network)
+
+    export_obs_shapes = dict(network_obs_shapes)
+    export_obs_shapes["state"] = _extract_state_shape(observation_size_cfg, params)
+
+    model_proto = franka_ppo_to_onnx.convert_policy_to_onnx(
+        make_inference_fn=make_inference_fn,
+        params=params,
+        observation_shapes=export_obs_shapes,
+        pixel_obs_keys=pixel_obs_keys,
+        state_obs_key="",
+        normalise_channels=True,
+    )
+
+    if output_path is None:
+        output_file = ckpt_dir / "policy.onnx"
+    else:
+        output_file = Path(output_path).expanduser().resolve()
+    output_file.parent.mkdir(parents=True, exist_ok=True)
+    output_file.write_bytes(model_proto.SerializeToString())
+    print(f"Wrote ONNX model: {output_file}")
+
+    if num_tests <= 0:
+        return output_file
+    if ort is None:
+        print("onnxruntime not installed, skipping parity checks.")
+        return output_file
+
+    session = ort.InferenceSession(
+        model_proto.SerializeToString(),
+        sess_options=_make_ort_session_options(),
+        providers=["CPUExecutionProvider"],
+    )
+    onnx_shapes = _onnx_input_shapes(session)
+    jax_policy = make_inference_fn(params, deterministic=True)
+    jax_input_keys = set(network_obs_shapes)
+
+    rng = np.random.default_rng(0)
+    max_err = 0.0
+    for i in range(num_tests):
+        obs = _build_random_obs(rng, onnx_shapes)
+        onnx_action = session.run(["continuous_actions"], obs)[0][0]
+        jax_obs = {
+            k: jax.numpy.asarray(v) for k, v in obs.items() if k in jax_input_keys
+        }
+        jax_action = np.asarray(jax_policy(jax_obs, jax.random.PRNGKey(i))[0][0])
+        err = float(np.max(np.abs(onnx_action - jax_action)))
+        max_err = max(max_err, err)
+        print(f"sample={i} max_abs_err={err:.6e}")
+
+    print(f"overall max_abs_err={max_err:.6e} (atol={atol:.1e})")
+    if max_err > atol:
+        raise AssertionError(
+            f"ONNX/JAX mismatch: max_abs_err={max_err:.6e} > atol={atol}"
+        )
+
+    return output_file
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Convert PPO vision policy checkpoint to ONNX."
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        required=True,
+        help="Checkpoint step directory or parent directory containing step subdirs.",
+    )
+    parser.add_argument(
+        "--output_path",
+        default=None,
+        help="Output ONNX file path (default: <checkpoint_dir>/policy.onnx).",
+    )
+    parser.add_argument(
+        "--num_tests",
+        type=int,
+        default=0,
+        help="Number of ONNX-vs-JAX parity tests to run.",
+    )
+    parser.add_argument(
+        "--atol",
+        type=float,
+        default=1e-4,
+        help="Absolute tolerance for parity checks.",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = _parse_args()
+    convert_checkpoint_to_onnx(
+        checkpoint_path=args.checkpoint_path,
+        output_path=args.output_path,
+        num_tests=args.num_tests,
+        atol=args.atol,
+    )
diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py
new file mode 100644
index 000000000..a84d73c1d
--- /dev/null
+++ b/franka_experiments/train_pick_cartesian_vision_ppo.py
@@ -0,0 +1,211 @@
+"""Train PPO on PandaPickCubeCartesianExtended from pixels via SS2R env factory.
+
+This script follows the Madrona-MJX tutorial flow:
+1) Build a vision environment with domain randomization.
+2) Train a vision PPO policy.
+3) Print reward metrics during training.
+
+The environment is always created with:
+`ss2r.benchmark_suites.make_mujoco_playground_envs`.
+""" + +import functools +from datetime import datetime +from pathlib import Path +from typing import Any + +import brax.training.agents.ppo.train as brax_ppo_train +from absl import app, flags +from brax.training.agents.ppo import networks_vision as ppo_networks_vision +from flax import linen +from ml_collections import config_dict +from mujoco_playground.config import manipulation_params + +from ss2r import benchmark_suites +from ss2r.benchmark_suites.mujoco_playground.pick_cartesian import ( + pick_cartesian as pick_cartesian_task, +) + +_ENV_NAME = "PandaPickCubeCartesianExtended" +_BASE_CONFIG_ENV_NAME = "PandaPickCubeCartesian" + +_NUM_ENVS = flags.DEFINE_integer("num_envs", 1024, "Number of parallel environments") +_NUM_TIMESTEPS = flags.DEFINE_integer( + "num_timesteps", 7_000_000, "Total environment steps for PPO training" +) +_SEED = flags.DEFINE_integer("seed", 0, "PRNG seed") +_SAVE_CHECKPOINT_PATH = flags.DEFINE_string( + "save_checkpoint_path", + "franka_experiments/checkpoints/pick_cartesian_vision_ppo", + "Directory to save Brax PPO checkpoints. 
Set empty string to disable.", +) +_TRAIN_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( + "train_domain_randomization", True, "Enable domain randomization for training" +) +_EVAL_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( + "eval_domain_randomization", True, "Enable domain randomization for evaluation" +) +_RENDER_WIDTH = flags.DEFINE_integer("render_width", 64, "Render width") +_RENDER_HEIGHT = flags.DEFINE_integer("render_height", 64, "Render height") +_USE_RASTERIZER = flags.DEFINE_boolean( + "use_rasterizer", False, "Use rasterizer backend for rendering" +) + + +def _make_ppo_networks_vision_ckpt_compatible( + observation_size, + action_size, + preprocess_observations_fn, + policy_hidden_layer_sizes=(256, 256), + value_hidden_layer_sizes=(256, 256), + normalise_channels=False, + policy_obs_key="", + value_obs_key="", + activation_name: str = "relu", +): + # FIXME: Brax checkpoint.network_config rejects non-default `activation` kwargs. + # Route activation through `activation_name` so checkpoints can be saved while + # still building the ReLU network used in this experiment. 
+    activation = getattr(linen, activation_name, None)
+    if activation is None:
+        raise ValueError(f"Unsupported activation_name={activation_name!r}")
+    return ppo_networks_vision.make_ppo_networks_vision(
+        observation_size=observation_size,
+        action_size=action_size,
+        preprocess_observations_fn=preprocess_observations_fn,
+        policy_hidden_layer_sizes=policy_hidden_layer_sizes,
+        value_hidden_layer_sizes=value_hidden_layer_sizes,
+        activation=activation,
+        normalise_channels=normalise_channels,
+        policy_obs_key=policy_obs_key,
+        value_obs_key=value_obs_key,
+    )
+
+
+def _set_nested(cfg: config_dict.ConfigDict, key: str, value: Any) -> None:
+    keys = key.split(".")
+    node = cfg
+    for part in keys[:-1]:
+        node = node[part]
+    node[keys[-1]] = value
+
+
+def _build_env_and_cfg():
+    num_envs = _NUM_ENVS.value
+    env_cfg = pick_cartesian_task.default_config()
+    episode_length = int(4 / env_cfg.ctrl_dt)
+
+    overrides = {
+        "use_ball": False,
+        "use_x": False,
+        "episode_length": episode_length,
+        "vision": True,
+        "obs_noise.brightness": [0.75, 2.0],
+        "vision_config.use_rasterizer": _USE_RASTERIZER.value,
+        "vision_config.render_batch_size": num_envs,
+        "vision_config.render_width": _RENDER_WIDTH.value,
+        "vision_config.render_height": _RENDER_HEIGHT.value,
+        "box_init_range": 0.1,
+        "action_history_length": 5,
+        "success_threshold": 0.03,
+    }
+    for k, v in overrides.items():
+        _set_nested(env_cfg, k, v)
+
+    cfg = config_dict.ConfigDict()
+    cfg.environment = config_dict.ConfigDict()
+    cfg.environment.domain_name = "mujoco_playground"
+    cfg.environment.task_name = _ENV_NAME
+    cfg.environment.task_params = env_cfg.to_dict()
+    cfg.environment.train_params = config_dict.ConfigDict()
+    cfg.environment.eval_params = config_dict.ConfigDict()
+
+    cfg.agent = config_dict.ConfigDict()
+    cfg.agent.use_vision = True
+
+    cfg.training = config_dict.ConfigDict()
+    cfg.training.seed = _SEED.value
+    cfg.training.num_envs = num_envs
+    cfg.training.num_eval_envs = num_envs
+    cfg.training.train_domain_randomization = _TRAIN_DOMAIN_RANDOMIZATION.value
+    cfg.training.eval_domain_randomization = _EVAL_DOMAIN_RANDOMIZATION.value
+    cfg.training.episode_length = episode_length
+    cfg.training.action_repeat = 1
+    cfg.training.hard_resets = False
+    cfg.training.nonepisodic = False
+    cfg.training.action_delay = config_dict.ConfigDict(
+        {"enable": False, "max_delay": 0}
+    )
+
+    train_env, _ = benchmark_suites.make_mujoco_playground_envs(
+        cfg, lambda env: env, lambda env: env
+    )
+    return train_env, episode_length
+
+
+def main(argv):
+    del argv
+
+    print(f"Using Brax PPO trainer from: {brax_ppo_train.__file__}")
+    save_checkpoint_path = None
+    if _SAVE_CHECKPOINT_PATH.value:
+        save_checkpoint_path = str(
+            Path(_SAVE_CHECKPOINT_PATH.value).expanduser().resolve()
+        )
+        Path(save_checkpoint_path).mkdir(parents=True, exist_ok=True)
+    print(
+        "Checkpoint saving: "
+        f"{save_checkpoint_path if save_checkpoint_path else 'disabled'}"
+    )
+
+    train_env, episode_length = _build_env_and_cfg()
+    num_envs = _NUM_ENVS.value
+
+    network_factory = functools.partial(
+        _make_ppo_networks_vision_ckpt_compatible,
+        policy_hidden_layer_sizes=[256, 256],
+        value_hidden_layer_sizes=[256, 256],
+        activation_name="relu",
+        normalise_channels=True,
+    )
+
+    ppo_params = manipulation_params.brax_vision_ppo_config(_BASE_CONFIG_ENV_NAME)
+    ppo_params.num_timesteps = _NUM_TIMESTEPS.value
+    ppo_params.num_envs = num_envs
+    ppo_params.num_eval_envs = num_envs
+    ppo_params.episode_length = episode_length
+    ppo_params.action_repeat = 1
+    del ppo_params.network_factory
+    ppo_params.network_factory = network_factory
+
+    times = [datetime.now()]
+
+    def progress(num_steps, metrics):
+        if "eval/episode_reward" in metrics:
+            print(
+                f"{num_steps}: eval/episode_reward={metrics['eval/episode_reward']:.3f} "
+                f"+- {metrics.get('eval/episode_reward_std', 0.0):.3f}"
+            )
+        times.append(datetime.now())
+
+    train_kwargs = dict(ppo_params)
+    train_kwargs.update(
+        {
+            "augment_pixels": True,
+            "wrap_env": False,
+            "madrona_backend": True,
+            "progress_fn": progress,
+            "seed": _SEED.value,
+            "save_checkpoint_path": save_checkpoint_path,
+        }
+    )
+    train_fn = functools.partial(brax_ppo_train.train, **train_kwargs)
+
+    _ = train_fn(environment=train_env, eval_env=None)
+    if len(times) > 1:
+        print(f"time to jit: {times[1] - times[0]}")
+        print(f"time to train: {times[-1] - times[1]}")
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py
new file mode 100644
index 000000000..3182d9486
--- /dev/null
+++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py
@@ -0,0 +1,351 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+import jax
+import numpy as np
+
+try:
+    import tensorflow as tf
+    import tf2onnx
+    from tensorflow.keras import layers  # type: ignore
+except ImportError:
+    tf = None
+    layers = None
+    tf2onnx = None
+    logging.warning("TensorFlow is not installed. Skipping conversion to ONNX.")
+
+
+logger = logging.getLogger(__name__)
+
+
+def _get_path(tree: Any, *path: str, default: Any = None) -> Any:
+    cur = tree
+    for key in path:
+        if cur is None:
+            return default
+        if isinstance(cur, Mapping):
+            cur = cur.get(key, default)
+        else:
+            cur = getattr(cur, key, default)
+        if cur is default:
+            return default
+    return cur
+
+
+def _extract_policy_params(params: Any) -> Mapping[str, Any]:
+    if isinstance(params, (tuple, list)):
+        if len(params) < 2:
+            raise ValueError("Expected params as (normalizer, policy, value) tuple.")
+        policy_params = params[1]
+    else:
+        policy_params = params
+
+    if isinstance(policy_params, Mapping) and "params" in policy_params:
+        policy_params = policy_params["params"]
+    if not isinstance(policy_params, Mapping):
+        raise TypeError("Could not extract policy parameter mapping.")
+    return policy_params
+
+
+def _sorted_hidden_keys(mlp_params: Mapping[str, Any]) -> list[str]:
+    hidden_keys = [k for k in mlp_params.keys() if k.startswith("hidden_")]
+    hidden_keys.sort(key=lambda x: int(x.split("_")[-1]))
+    return hidden_keys
+
+
+def _infer_action_size(policy_params: Mapping[str, Any]) -> int:
+    mlp_params = policy_params["MLP_0"]
+    hidden_keys = _sorted_hidden_keys(mlp_params)
+    if not hidden_keys:
+        raise ValueError("MLP_0 has no hidden_* layers.")
+    last_bias = np.asarray(mlp_params[hidden_keys[-1]]["bias"])
+    if last_bias.shape[-1] % 2 != 0:  # type: ignore
+        raise ValueError(
+            f"Expected even policy logits dimension, got {last_bias.shape[-1]}."  # type: ignore
+        )
+    return int(last_bias.shape[-1] // 2)  # type: ignore
+
+
+def _infer_hidden_layers(policy_params: Mapping[str, Any]) -> tuple[int, ...]:
+    mlp_params = policy_params["MLP_0"]
+    hidden_keys = _sorted_hidden_keys(mlp_params)
+    if len(hidden_keys) < 2:
+        return ()
+    return tuple(
+        int(np.asarray(mlp_params[k]["bias"]).shape[0])  # type: ignore
+        for k in hidden_keys[:-1]  # type: ignore
+    )
+
+
+def _resolve_render_hw(
+    cfg: Any, default_h: int = 64, default_w: int = 64
+) -> tuple[int, int]:
+    h = _get_path(cfg, "environment", "task_params", "vision_config", "render_height")
+    w = _get_path(cfg, "environment", "task_params", "vision_config", "render_width")
+    if h is None or w is None:
+        return default_h, default_w
+    return int(h), int(w)
+
+
+def _resolve_hidden_layers(cfg: Any, fallback: tuple[int, ...]) -> tuple[int, ...]:
+    from_cfg = _get_path(cfg, "agent", "policy_hidden_layer_sizes")
+    if from_cfg is None:
+        return fallback
+    return tuple(int(x) for x in from_cfg)
+
+
+def _resolve_tf_activation(cfg: Any):
+    activation_name = _get_path(cfg, "agent", "activation")
+    if isinstance(activation_name, str):
+        activation = getattr(tf.nn, activation_name, None)
+        if activation is not None:
+            return activation
+    # Matches train_pick_cartesian_vision_ppo.py
+    return tf.nn.relu
+
+
+if tf is not None and layers is not None:
+
+    class VisionPPOPolicy(tf.keras.Model):
+        """TensorFlow equivalent of Brax VisionMLP policy head for PPO."""
+
+        def __init__(
+            self,
+            action_size: int,
+            pixel_obs_keys: Sequence[str],
+            hidden_layer_sizes: Sequence[int],
+            activation: Any,
+            normalise_channels: bool = True,
+            state_obs_key: str = "",
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.action_size = action_size
+            self.pixel_obs_keys = tuple(pixel_obs_keys)
+            self.normalise_channels = normalise_channels
+            del state_obs_key  # state is intentionally not used by this policy.
+
+            self.cnn_blocks = []
+            for i, _ in enumerate(self.pixel_obs_keys):
+                cnn = tf.keras.Sequential(name=f"CNN_{i}")
+                cnn.add(
+                    layers.Conv2D(
+                        filters=32,
+                        kernel_size=(8, 8),
+                        strides=(4, 4),
+                        padding="same",
+                        activation=tf.nn.relu,
+                        use_bias=False,
+                        name="Conv_0",
+                    )
+                )
+                cnn.add(
+                    layers.Conv2D(
+                        filters=64,
+                        kernel_size=(4, 4),
+                        strides=(2, 2),
+                        padding="same",
+                        activation=tf.nn.relu,
+                        use_bias=False,
+                        name="Conv_1",
+                    )
+                )
+                cnn.add(
+                    layers.Conv2D(
+                        filters=64,
+                        kernel_size=(3, 3),
+                        strides=(1, 1),
+                        padding="same",
+                        activation=tf.nn.relu,
+                        use_bias=False,
+                        name="Conv_2",
+                    )
+                )
+                self.cnn_blocks.append(cnn)
+
+            self.mlp_block = tf.keras.Sequential(name="MLP_0")
+            layer_sizes = list(hidden_layer_sizes) + [action_size * 2]
+            for i, size in enumerate(layer_sizes):
+                self.mlp_block.add(
+                    layers.Dense(
+                        units=size,
+                        activation=activation if i < len(layer_sizes) - 1 else None,
+                        name=f"hidden_{i}",
+                    )
+                )
+
+        @staticmethod
+        def _normalise_channels(x: tf.Tensor, eps: float = 1e-6) -> tf.Tensor:
+            mean = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
+            var = tf.reduce_mean(tf.square(x - mean), axis=(1, 2), keepdims=True)
+            return (x - mean) * tf.math.rsqrt(var + eps)
+
+        def call(self, obs: Mapping[str, tf.Tensor]) -> tf.Tensor:
+            cnn_outs = []
+            for i, key in enumerate(self.pixel_obs_keys):
+                hidden = obs[key]
+                if self.normalise_channels:
+                    hidden = self._normalise_channels(hidden)
+                hidden = self.cnn_blocks[i](hidden)
+                hidden = tf.reduce_mean(hidden, axis=(1, 2))
+                cnn_outs.append(hidden)
+
+            logits = self.mlp_block(tf.concat(cnn_outs, axis=-1))
+            loc, _ = tf.split(logits, 2, axis=-1)
+            action = tf.tanh(loc)
+            return action
+
+else:
+
+    class VisionPPOPolicy:  # type: ignore[no-redef]
+        pass
+
+
+def transfer_weights(
+    policy_params: Mapping[str, Any], tf_model: tf.keras.Model
+) -> None:
+    """Copies Flax PPO vision policy weights into TF model."""
+    cnn_names = sorted(
+        [k for k in policy_params.keys() if k.startswith("CNN_")],
+        key=lambda x: int(x.split("_")[-1]),
+    )
+    for i, cnn_name in enumerate(cnn_names):
+        tf_cnn = tf_model.get_layer(f"CNN_{i}")
+        for conv_name in ("Conv_0", "Conv_1", "Conv_2"):
+            kernel = np.asarray(policy_params[cnn_name][conv_name]["kernel"])
+            tf_layer = tf_cnn.get_layer(conv_name)
+            tf_layer.set_weights([kernel])
+
+    tf_mlp = tf_model.get_layer("MLP_0")
+    hidden_keys = _sorted_hidden_keys(policy_params["MLP_0"])
+    for hidden_name in hidden_keys:
+        layer_params = policy_params["MLP_0"][hidden_name]
+        kernel = np.asarray(layer_params["kernel"])
+        bias = np.asarray(layer_params["bias"])
+        tf_layer = tf_mlp.get_layer(hidden_name)
+        tf_layer.set_weights([kernel, bias])
+
+
+def convert_policy_to_onnx(
+    make_inference_fn: Any,
+    params: Any,
+    cfg: Any | None = None,
+    observation_shapes: Mapping[str, tuple[int, ...]] | None = None,
+    pixel_obs_keys: Sequence[str] | None = None,
+    state_obs_key: str = "",
+    normalise_channels: bool = True,
+):
+    """Converts PPO vision policy params to ONNX via TensorFlow."""
+    if tf is None or tf2onnx is None or layers is None:
+        raise ImportError("TensorFlow/tf2onnx is required for ONNX export.")
+
+    policy_params = _extract_policy_params(params)
+    action_size = _infer_action_size(policy_params)
+    hidden_layers = _resolve_hidden_layers(cfg, _infer_hidden_layers(policy_params))
+    activation = _resolve_tf_activation(cfg)
+
+    cnn_names = sorted(
+        [k for k in policy_params.keys() if k.startswith("CNN_")],
+        key=lambda x: int(x.split("_")[-1]),
+    )
+    if not cnn_names:
+        raise ValueError("No CNN_* modules found in policy params.")
+
+    if pixel_obs_keys is None:
+        pixel_obs_keys = tuple(f"pixels/view_{i}" for i in range(len(cnn_names)))
+    if len(pixel_obs_keys) != len(cnn_names):
+        raise ValueError(
+            f"Expected {len(cnn_names)} pixel keys, got {len(pixel_obs_keys)}."
+        )
+
+    if observation_shapes is None:
+        observation_shapes = {}
+    else:
+        observation_shapes = dict(observation_shapes)
+
+    h, w = _resolve_render_hw(cfg)
+    used_observation_shapes: dict[str, tuple[int, ...]] = {}
+    for i, key in enumerate(pixel_obs_keys):
+        if key in observation_shapes:
+            used_observation_shapes[key] = tuple(observation_shapes[key])
+            continue
+        channels = int(
+            np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2]  # type: ignore
+        )
+        used_observation_shapes[key] = (h, w, channels)
+    del state_obs_key  # state is intentionally never passed to the policy network.
+
+    tf_policy_network = VisionPPOPolicy(
+        action_size=action_size,
+        pixel_obs_keys=pixel_obs_keys,
+        hidden_layer_sizes=hidden_layers,
+        activation=activation,
+        normalise_channels=normalise_channels,
+    )
+
+    model_input_shapes = dict(used_observation_shapes)
+    if "state" in observation_shapes:
+        model_input_shapes["state"] = tuple(observation_shapes["state"])
+    dummy_obs = {
+        k: np.ones((1, *shape), dtype=np.float32)
+        for k, shape in model_input_shapes.items()
+    }
+    tf_policy_network(dummy_obs).numpy()
+    transfer_weights(policy_params, tf_policy_network)
+
+    inference_fn = make_inference_fn(params, deterministic=True)
+    jax_obs = {
+        k: jax.numpy.asarray(v)
+        for k, v in dummy_obs.items()
+        if k in used_observation_shapes
+    }
+    jax_pred = np.asarray(inference_fn(jax_obs, jax.random.PRNGKey(0))[0][0])
+    tf_pred = np.asarray(tf_policy_network(dummy_obs).numpy()[0])
+    max_abs_err = float(np.max(np.abs(jax_pred - tf_pred)))
+    logger.info("PPO vision ONNX export sanity max_abs_err=%.6e", max_abs_err)
+
+    tf_policy_network.output_names = ["continuous_actions"]
+    model_proto, _ = tf2onnx.convert.from_keras(
+        tf_policy_network,
+        input_signature=[
+            {
+                k: tf.TensorSpec([1, *shape], tf.float32, name=k)
+                for k, shape in model_input_shapes.items()
+            }
+        ],
+        opset=11,
+    )
+    return model_proto
+
+
+def make_franka_policy(
+    make_policy_fn: Any, params: Any, cfg: Any | None = None
+) -> bytes:
+    """Builds and serializes an ONNX PPO vision policy for Franka pick tasks."""
+    policy_params = _extract_policy_params(params)
+    cnn_names = sorted(
+        [k for k in policy_params.keys() if k.startswith("CNN_")],
+        key=lambda x: int(x.split("_")[-1]),
+    )
+    h, w = _resolve_render_hw(cfg)
+    obs_shapes = {}
+    obs_keys = []
+    for i, cnn_name in enumerate(cnn_names):
+        channels = int(np.asarray(policy_params[cnn_name]["Conv_0"]["kernel"]).shape[2])  # type: ignore
+        key = f"pixels/view_{i}"
+        obs_keys.append(key)
+        obs_shapes[key] = (h, w, channels)
+
+    model_proto = convert_policy_to_onnx(
+        make_policy_fn,
+        params,
+        cfg=cfg,
+        observation_shapes=obs_shapes,
+        pixel_obs_keys=obs_keys,
+        state_obs_key="",
+        normalise_channels=True,
+    )
+    return model_proto.SerializeToString()