From b87cbd32ce8ad43dd2c159511bef104a9b708185 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:34:21 +0100 Subject: [PATCH 01/18] Add training script --- .../train_pick_cartesian_vision_ppo.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 franka_experiments/train_pick_cartesian_vision_ppo.py diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py new file mode 100644 index 000000000..aeb9c97b1 --- /dev/null +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -0,0 +1,156 @@ +"""Train PPO on PandaPickCubeCartesianExtended from pixels via SS2R env factory. + +This script follows the Madrona-MJX tutorial flow: +1) Build a vision environment with domain randomization. +2) Train a vision PPO policy. +3) Print reward metrics during training. + +The environment is always created with: +`ss2r.benchmark_suites.make_mujoco_playground_envs`. +""" + +import functools +from datetime import datetime +from typing import Any + +from absl import app, flags +from brax.training.agents.ppo import networks_vision as ppo_networks_vision +from brax.training.agents.ppo import train as ppo +from flax import linen +from ml_collections import config_dict +from mujoco_playground import registry +from mujoco_playground.config import manipulation_params + +from ss2r import benchmark_suites + +_ENV_NAME = "PandaPickCubeCartesianExtended" +_BASE_CONFIG_ENV_NAME = "PandaPickCubeCartesian" + +_NUM_ENVS = flags.DEFINE_integer("num_envs", 1024, "Number of parallel environments") +_NUM_TIMESTEPS = flags.DEFINE_integer( + "num_timesteps", 7_000_000, "Total environment steps for PPO training" +) +_SEED = flags.DEFINE_integer("seed", 0, "PRNG seed") +_TRAIN_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( + "train_domain_randomization", True, "Enable domain randomization for training" +) +_EVAL_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( + "eval_domain_randomization", True, "Enable 
domain randomization for evaluation" +) +_RENDER_WIDTH = flags.DEFINE_integer("render_width", 64, "Render width") +_RENDER_HEIGHT = flags.DEFINE_integer("render_height", 64, "Render height") +_USE_RASTERIZER = flags.DEFINE_boolean( + "use_rasterizer", False, "Use rasterizer backend for rendering" +) + + +def _set_nested(cfg: config_dict.ConfigDict, key: str, value: Any) -> None: + keys = key.split(".") + node = cfg + for part in keys[:-1]: + node = node[part] + node[keys[-1]] = value + + +def _build_env_and_cfg(): + num_envs = _NUM_ENVS.value + env_cfg = registry.get_default_config(_ENV_NAME) + episode_length = int(4 / env_cfg.ctrl_dt) + + overrides = { + "episode_length": episode_length, + "vision": True, + "obs_noise.brightness": [0.75, 2.0], + "vision_config.use_rasterizer": _USE_RASTERIZER.value, + "vision_config.render_batch_size": num_envs, + "vision_config.render_width": _RENDER_WIDTH.value, + "vision_config.render_height": _RENDER_HEIGHT.value, + "box_init_range": 0.1, + "action_history_length": 5, + "success_threshold": 0.03, + } + for k, v in overrides.items(): + _set_nested(env_cfg, k, v) + + cfg = config_dict.ConfigDict() + cfg.environment = config_dict.ConfigDict() + cfg.environment.domain_name = "mujoco_playground" + cfg.environment.task_name = _ENV_NAME + cfg.environment.task_params = env_cfg.to_dict() + cfg.environment.train_params = config_dict.ConfigDict() + cfg.environment.eval_params = config_dict.ConfigDict() + + cfg.agent = config_dict.ConfigDict() + cfg.agent.use_vision = True + + cfg.training = config_dict.ConfigDict() + cfg.training.seed = _SEED.value + cfg.training.num_envs = num_envs + cfg.training.num_eval_envs = num_envs + cfg.training.train_domain_randomization = _TRAIN_DOMAIN_RANDOMIZATION.value + cfg.training.eval_domain_randomization = _EVAL_DOMAIN_RANDOMIZATION.value + cfg.training.episode_length = episode_length + cfg.training.action_repeat = 1 + cfg.training.hard_resets = False + cfg.training.nonepisodic = False + 
cfg.training.action_delay = config_dict.ConfigDict( + {"enable": False, "max_delay": 0} + ) + + train_env, _ = benchmark_suites.make_mujoco_playground_envs( + cfg, lambda env: env, lambda env: env + ) + return train_env, episode_length + + +def main(argv): + del argv + + train_env, episode_length = _build_env_and_cfg() + num_envs = _NUM_ENVS.value + + network_factory = functools.partial( + ppo_networks_vision.make_ppo_networks_vision, + policy_hidden_layer_sizes=[256, 256], + value_hidden_layer_sizes=[256, 256], + activation=linen.relu, + normalise_channels=True, + ) + + ppo_params = manipulation_params.brax_vision_ppo_config(_BASE_CONFIG_ENV_NAME) + ppo_params.num_timesteps = _NUM_TIMESTEPS.value + ppo_params.num_envs = num_envs + ppo_params.num_eval_envs = num_envs + ppo_params.episode_length = episode_length + ppo_params.action_repeat = 1 + del ppo_params.network_factory + ppo_params.network_factory = network_factory + + times = [datetime.now()] + + def progress(num_steps, metrics): + if "eval/episode_reward" in metrics: + print( + f"{num_steps}: eval/episode_reward={metrics['eval/episode_reward']:.3f} " + f"+- {metrics.get('eval/episode_reward_std', 0.0):.3f}" + ) + times.append(datetime.now()) + + train_fn = functools.partial( + ppo.train, + augment_pixels=True, + wrap_env=False, + madrona_backend=True, + progress_fn=progress, + seed=_SEED.value, + **dict(ppo_params), + ) + + _ = train_fn(environment=train_env, eval_env=None) + if len(times) > 1: + print(f"time to jit: {times[1] - times[0]}") + print(f"time to train: {times[-1] - times[1]}") + + +if __name__ == "__main__": + app.run(main) From d87ed0967563dfeccbf8414fb6fdd64479e18007 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:37:21 +0100 Subject: [PATCH 02/18] Fixes --- franka_experiments/train_pick_cartesian_vision_ppo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py 
b/franka_experiments/train_pick_cartesian_vision_ppo.py index aeb9c97b1..e3cf92b1f 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -18,10 +18,12 @@ from brax.training.agents.ppo import train as ppo from flax import linen from ml_collections import config_dict -from mujoco_playground import registry from mujoco_playground.config import manipulation_params from ss2r import benchmark_suites +from ss2r.benchmark_suites.mujoco_playground.pick_cartesian import ( + pick_cartesian as pick_cartesian_task, +) _ENV_NAME = "PandaPickCubeCartesianExtended" _BASE_CONFIG_ENV_NAME = "PandaPickCubeCartesian" @@ -54,7 +56,7 @@ def _set_nested(cfg: config_dict.ConfigDict, key: str, value: Any) -> None: def _build_env_and_cfg(): num_envs = _NUM_ENVS.value - env_cfg = registry.get_default_config(_ENV_NAME) + env_cfg = pick_cartesian_task.default_config() episode_length = int(4 / env_cfg.ctrl_dt) overrides = { From d11e680197b9547bcec47511c868f9cecf74d7a4 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:38:43 +0100 Subject: [PATCH 03/18] Fixes --- franka_experiments/train_pick_cartesian_vision_ppo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index e3cf92b1f..a34dd9ed6 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -60,6 +60,8 @@ def _build_env_and_cfg(): episode_length = int(4 / env_cfg.ctrl_dt) overrides = { + "use_ball": False, + "use_x": False, "episode_length": episode_length, "vision": True, "obs_noise.brightness": [0.75, 2.0], From d6cf3c6b567e04a3e13d6d8ac0eaf88b155e9f7c Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:43:56 +0100 Subject: [PATCH 04/18] Use standard brax ppo --- .../train_pick_cartesian_vision_ppo.py | 21 +++++++++++-------- 1 file changed, 12 
insertions(+), 9 deletions(-) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index a34dd9ed6..6945bf3c3 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -13,9 +13,9 @@ from datetime import datetime from typing import Any +import brax.training.agents.ppo.train as brax_ppo_train from absl import app, flags from brax.training.agents.ppo import networks_vision as ppo_networks_vision -from brax.training.agents.ppo import train as ppo from flax import linen from ml_collections import config_dict from mujoco_playground.config import manipulation_params @@ -110,6 +110,7 @@ def _build_env_and_cfg(): def main(argv): del argv + print(f"Using Brax PPO trainer from: {brax_ppo_train.__file__}") train_env, episode_length = _build_env_and_cfg() num_envs = _NUM_ENVS.value @@ -140,15 +141,17 @@ def progress(num_steps, metrics): ) times.append(datetime.now()) - train_fn = functools.partial( - ppo.train, - augment_pixels=True, - wrap_env=False, - madrona_backend=True, - progress_fn=progress, - seed=_SEED.value, - **dict(ppo_params), + train_kwargs = dict(ppo_params) + train_kwargs.update( + { + "augment_pixels": True, + "wrap_env": False, + "madrona_backend": True, + "progress_fn": progress, + "seed": _SEED.value, + } ) + train_fn = functools.partial(brax_ppo_train.train, **train_kwargs) _ = train_fn(environment=train_env, eval_env=None) if len(times) > 1: From 3617872a78593f6f792f36a73add1edd33a1d73c Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:46:06 +0100 Subject: [PATCH 05/18] Add vision conversion scripts --- franka_experiments/ppo_vision_flax_to_onnx.py | 38 ++ ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 336 ++++++++++++++++++ 2 files changed, 374 insertions(+) create mode 100644 franka_experiments/ppo_vision_flax_to_onnx.py create mode 100644 ss2r/algorithms/ppo/franka_ppo_to_onnx.py diff --git 
a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py new file mode 100644 index 000000000..2f2feed6b --- /dev/null +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -0,0 +1,38 @@ +import jax +from brax.training.agents.ppo import networks as ppo_networks +from brax.training.agents.ppo import networks_vision as ppo_networks_vision +from flax import linen + +from ss2r.algorithms.ppo import franka_ppo_to_onnx + + +def test_policy_to_onnx_export(): + obs_shape = (64, 64, 1) + action_size = 3 + ppo_network = ppo_networks_vision.make_ppo_networks_vision( + observation_size={"pixels/view_0": obs_shape, "state": (10,)}, + action_size=action_size, + policy_hidden_layer_sizes=(256, 256), + value_hidden_layer_sizes=(256, 256), + activation=linen.relu, + normalise_channels=True, + ) + policy_params = ppo_network.policy_network.init(jax.random.PRNGKey(0)) + make_inference_fn = ppo_networks.make_inference_fn(ppo_network) + try: + model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( + make_inference_fn=make_inference_fn, + params=(None, policy_params, None), + observation_shapes={"pixels/view_0": obs_shape}, + pixel_obs_keys=("pixels/view_0",), + state_obs_key="", + normalise_channels=True, + ) + except ImportError as exc: + print(f"Skipping export: {exc}") + return + print(f"ONNX bytes: {len(model_proto.SerializeToString())}") + + +if __name__ == "__main__": + test_policy_to_onnx_export() diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py new file mode 100644 index 000000000..64ca52993 --- /dev/null +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -0,0 +1,336 @@ +from __future__ import annotations + +import logging +from collections.abc import Mapping, Sequence +from typing import Any + +import jax +import numpy as np + +try: + import tensorflow as tf + import tf2onnx + from tensorflow.keras import layers # type: ignore +except ImportError: + tf = None + layers = None + tf2onnx = 
None + logging.warning("TensorFlow is not installed. Skipping conversion to ONNX.") + + +logger = logging.getLogger(__name__) + + +def _get_path(tree: Any, *path: str, default: Any = None) -> Any: + cur = tree + for key in path: + if cur is None: + return default + if isinstance(cur, Mapping): + cur = cur.get(key, default) + else: + cur = getattr(cur, key, default) + if cur is default: + return default + return cur + + +def _extract_policy_params(params: Any) -> Mapping[str, Any]: + if isinstance(params, (tuple, list)): + if len(params) < 2: + raise ValueError("Expected params as (normalizer, policy, value) tuple.") + policy_params = params[1] + else: + policy_params = params + + if isinstance(policy_params, Mapping) and "params" in policy_params: + policy_params = policy_params["params"] + if not isinstance(policy_params, Mapping): + raise TypeError("Could not extract policy parameter mapping.") + return policy_params + + +def _sorted_hidden_keys(mlp_params: Mapping[str, Any]) -> list[str]: + hidden_keys = [k for k in mlp_params.keys() if k.startswith("hidden_")] + hidden_keys.sort(key=lambda x: int(x.split("_")[-1])) + return hidden_keys + + +def _infer_action_size(policy_params: Mapping[str, Any]) -> int: + mlp_params = policy_params["MLP_0"] + hidden_keys = _sorted_hidden_keys(mlp_params) + if not hidden_keys: + raise ValueError("MLP_0 has no hidden_* layers.") + last_bias = np.asarray(mlp_params[hidden_keys[-1]]["bias"]) + if last_bias.shape[-1] % 2 != 0: # type: ignore + raise ValueError( + f"Expected even policy logits dimension, got {last_bias.shape[-1]}." 
# type: ignore + ) + return int(last_bias.shape[-1] // 2) # type: ignore + + +def _infer_hidden_layers(policy_params: Mapping[str, Any]) -> tuple[int, ...]: + mlp_params = policy_params["MLP_0"] + hidden_keys = _sorted_hidden_keys(mlp_params) + if len(hidden_keys) < 2: + return () + return tuple( + int(np.asarray(mlp_params[k]["bias"]).shape[0]) # type: ignore + for k in hidden_keys[:-1] # type: ignore + ) + + +def _resolve_render_hw( + cfg: Any, default_h: int = 64, default_w: int = 64 +) -> tuple[int, int]: + h = _get_path(cfg, "environment", "task_params", "vision_config", "render_height") + w = _get_path(cfg, "environment", "task_params", "vision_config", "render_width") + if h is None or w is None: + return default_h, default_w + return int(h), int(w) + + +def _resolve_hidden_layers(cfg: Any, fallback: tuple[int, ...]) -> tuple[int, ...]: + from_cfg = _get_path(cfg, "agent", "policy_hidden_layer_sizes") + if from_cfg is None: + return fallback + return tuple(int(x) for x in from_cfg) + + +def _resolve_tf_activation(cfg: Any): + activation_name = _get_path(cfg, "agent", "activation") + if isinstance(activation_name, str): + activation = getattr(tf.nn, activation_name, None) + if activation is not None: + return activation + # Matches train_pick_cartesian_vision_ppo.py + return tf.nn.relu + + +if tf is not None and layers is not None: + + class VisionPPOPolicy(tf.keras.Model): + """TensorFlow equivalent of Brax VisionMLP policy head for PPO.""" + + def __init__( + self, + action_size: int, + pixel_obs_keys: Sequence[str], + hidden_layer_sizes: Sequence[int], + activation: Any, + normalise_channels: bool = True, + state_obs_key: str = "", + **kwargs, + ): + super().__init__(**kwargs) + self.action_size = action_size + self.pixel_obs_keys = tuple(pixel_obs_keys) + self.normalise_channels = normalise_channels + self.state_obs_key = state_obs_key + + self.cnn_blocks = [] + for i, _ in enumerate(self.pixel_obs_keys): + cnn = tf.keras.Sequential(name=f"CNN_{i}") + 
cnn.add( + layers.Conv2D( + filters=32, + kernel_size=(8, 8), + strides=(4, 4), + activation=tf.nn.relu, + use_bias=False, + name="Conv_0", + ) + ) + cnn.add( + layers.Conv2D( + filters=64, + kernel_size=(4, 4), + strides=(2, 2), + activation=tf.nn.relu, + use_bias=False, + name="Conv_1", + ) + ) + cnn.add( + layers.Conv2D( + filters=64, + kernel_size=(3, 3), + strides=(1, 1), + activation=tf.nn.relu, + use_bias=False, + name="Conv_2", + ) + ) + self.cnn_blocks.append(cnn) + + self.mlp_block = tf.keras.Sequential(name="MLP_0") + layer_sizes = list(hidden_layer_sizes) + [action_size * 2] + for i, size in enumerate(layer_sizes): + self.mlp_block.add( + layers.Dense( + units=size, + activation=activation if i < len(layer_sizes) - 1 else None, + name=f"hidden_{i}", + ) + ) + + @staticmethod + def _normalise_channels(x: tf.Tensor, eps: float = 1e-6) -> tf.Tensor: + mean = tf.reduce_mean(x, axis=(1, 2), keepdims=True) + var = tf.reduce_mean(tf.square(x - mean), axis=(1, 2), keepdims=True) + return (x - mean) * tf.math.rsqrt(var + eps) + + def call(self, obs: Mapping[str, tf.Tensor]) -> tf.Tensor: + cnn_outs = [] + for i, key in enumerate(self.pixel_obs_keys): + hidden = obs[key] + if self.normalise_channels: + hidden = self._normalise_channels(hidden) + hidden = self.cnn_blocks[i](hidden) + hidden = tf.reduce_mean(hidden, axis=(1, 2)) + cnn_outs.append(hidden) + + if self.state_obs_key: + cnn_outs.append(obs[self.state_obs_key]) + + logits = self.mlp_block(tf.concat(cnn_outs, axis=-1)) + loc, _ = tf.split(logits, 2, axis=-1) + return tf.tanh(loc) + +else: + + class VisionPPOPolicy: # type: ignore[no-redef] + pass + + +def transfer_weights( + policy_params: Mapping[str, Any], tf_model: tf.keras.Model +) -> None: + """Copies Flax PPO vision policy weights into TF model.""" + cnn_names = sorted( + [k for k in policy_params.keys() if k.startswith("CNN_")], + key=lambda x: int(x.split("_")[-1]), + ) + for i, cnn_name in enumerate(cnn_names): + tf_cnn = 
tf_model.get_layer(f"CNN_{i}") + for conv_name in ("Conv_0", "Conv_1", "Conv_2"): + kernel = np.asarray(policy_params[cnn_name][conv_name]["kernel"]) + tf_layer = tf_cnn.get_layer(conv_name) + tf_layer.set_weights([kernel]) + + tf_mlp = tf_model.get_layer("MLP_0") + hidden_keys = _sorted_hidden_keys(policy_params["MLP_0"]) + for hidden_name in hidden_keys: + layer_params = policy_params["MLP_0"][hidden_name] + kernel = np.asarray(layer_params["kernel"]) + bias = np.asarray(layer_params["bias"]) + tf_layer = tf_mlp.get_layer(hidden_name) + tf_layer.set_weights([kernel, bias]) + + +def convert_policy_to_onnx( + make_inference_fn: Any, + params: Any, + cfg: Any | None = None, + observation_shapes: Mapping[str, tuple[int, ...]] | None = None, + pixel_obs_keys: Sequence[str] | None = None, + state_obs_key: str = "", + normalise_channels: bool = True, +): + """Converts PPO vision policy params to ONNX via TensorFlow.""" + if tf is None or tf2onnx is None or layers is None: + raise ImportError("TensorFlow/tf2onnx is required for ONNX export.") + + policy_params = _extract_policy_params(params) + action_size = _infer_action_size(policy_params) + hidden_layers = _resolve_hidden_layers(cfg, _infer_hidden_layers(policy_params)) + activation = _resolve_tf_activation(cfg) + + cnn_names = sorted( + [k for k in policy_params.keys() if k.startswith("CNN_")], + key=lambda x: int(x.split("_")[-1]), + ) + if not cnn_names: + raise ValueError("No CNN_* modules found in policy params.") + + if pixel_obs_keys is None: + pixel_obs_keys = tuple(f"pixels/view_{i}" for i in range(len(cnn_names))) + if len(pixel_obs_keys) != len(cnn_names): + raise ValueError( + f"Expected {len(cnn_names)} pixel keys, got {len(pixel_obs_keys)}." 
+ ) + + if observation_shapes is None: + h, w = _resolve_render_hw(cfg) + observation_shapes = {} + for i, key in enumerate(pixel_obs_keys): + channels = int( + np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore + ) + observation_shapes[key] = (h, w, channels) + + tf_policy_network = VisionPPOPolicy( + action_size=action_size, + pixel_obs_keys=pixel_obs_keys, + hidden_layer_sizes=hidden_layers, + activation=activation, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + ) + + dummy_obs = { + k: np.ones((1, *shape), dtype=np.float32) + for k, shape in observation_shapes.items() + } + tf_policy_network(dummy_obs).numpy() + transfer_weights(policy_params, tf_policy_network) + + inference_fn = make_inference_fn(params, deterministic=True) + jax_obs = {k: jax.numpy.asarray(v) for k, v in dummy_obs.items()} + jax_pred = np.asarray(inference_fn(jax_obs, jax.random.PRNGKey(0))[0][0]) + tf_pred = np.asarray(tf_policy_network(dummy_obs).numpy()[0]) + max_abs_err = float(np.max(np.abs(jax_pred - tf_pred))) + logger.info("PPO vision ONNX export sanity max_abs_err=%.6e", max_abs_err) + + tf_policy_network.output_names = ["continuous_actions"] + model_proto, _ = tf2onnx.convert.from_keras( + tf_policy_network, + input_signature=[ + { + k: tf.TensorSpec([1, *shape], tf.float32, name=k) + for k, shape in observation_shapes.items() + } + ], + opset=11, + ) + return model_proto + + +def make_franka_policy( + make_policy_fn: Any, params: Any, cfg: Any | None = None +) -> bytes: + """Builds and serializes an ONNX PPO vision policy for Franka pick tasks.""" + policy_params = _extract_policy_params(params) + cnn_names = sorted( + [k for k in policy_params.keys() if k.startswith("CNN_")], + key=lambda x: int(x.split("_")[-1]), + ) + h, w = _resolve_render_hw(cfg) + obs_shapes = {} + obs_keys = [] + for i, cnn_name in enumerate(cnn_names): + channels = int(np.asarray(policy_params[cnn_name]["Conv_0"]["kernel"]).shape[2]) # type: ignore 
+ key = f"pixels/view_{i}" + obs_keys.append(key) + obs_shapes[key] = (h, w, channels) + + model_proto = convert_policy_to_onnx( + make_policy_fn, + params, + cfg=cfg, + observation_shapes=obs_shapes, + pixel_obs_keys=obs_keys, + state_obs_key="", + normalise_channels=True, + ) + return model_proto.SerializeToString() From f177e4ef736478e26ff8f1615a5d454f034d72ca Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:50:52 +0100 Subject: [PATCH 06/18] Add logs --- franka_experiments/ppo_vision_flax_to_onnx.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 2f2feed6b..d276f61c2 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -1,12 +1,18 @@ import jax +import numpy as np from brax.training.agents.ppo import networks as ppo_networks from brax.training.agents.ppo import networks_vision as ppo_networks_vision from flax import linen from ss2r.algorithms.ppo import franka_ppo_to_onnx +try: + import onnxruntime as ort +except ImportError: + ort = None -def test_policy_to_onnx_export(): + +def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): obs_shape = (64, 64, 1) action_size = 3 ppo_network = ppo_networks_vision.make_ppo_networks_vision( @@ -31,7 +37,39 @@ def test_policy_to_onnx_export(): except ImportError as exc: print(f"Skipping export: {exc}") return - print(f"ONNX bytes: {len(model_proto.SerializeToString())}") + + if ort is None: + print("Skipping ONNX runtime parity check: onnxruntime is not installed.") + return + + session = ort.InferenceSession( + model_proto.SerializeToString(), + providers=["CPUExecutionProvider"], + ) + jax_policy = make_inference_fn((None, policy_params, None), deterministic=True) + + rng = np.random.default_rng(0) + max_err = 0.0 + for i in range(num_tests): + obs = {"pixels/view_0": rng.standard_normal((1, 
*obs_shape), dtype=np.float32)} + onnx_action = session.run(["continuous_actions"], obs)[0][0] + jax_action = np.asarray( + jax_policy( + {"pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"])}, + jax.random.PRNGKey(i), + )[0][0] + ) + err = float(np.max(np.abs(onnx_action - jax_action))) + max_err = max(max_err, err) + print(f"sample={i} max_abs_err={err:.6e}") + print(f" jax : {jax_action}") + print(f" onnx: {onnx_action}") + + print(f"overall max_abs_err={max_err:.6e} (atol={atol:.1e})") + if max_err > atol: + raise AssertionError( + f"ONNX/JAX mismatch: max_abs_err={max_err:.6e} > atol={atol}" + ) if __name__ == "__main__": From 542630f15849cac919479fac608f673dc8fa8d93 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 09:57:37 +0100 Subject: [PATCH 07/18] Make sure to use proprioceptive and no normalization --- franka_experiments/ppo_vision_flax_to_onnx.py | 16 ++- .../train_pick_cartesian_vision_ppo.py | 12 ++ ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 119 +++++++++++++++++- 3 files changed, 141 insertions(+), 6 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index d276f61c2..4e5fae0ca 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -22,6 +22,8 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): value_hidden_layer_sizes=(256, 256), activation=linen.relu, normalise_channels=True, + policy_obs_key="state", + value_obs_key="state", ) policy_params = ppo_network.policy_network.init(jax.random.PRNGKey(0)) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) @@ -29,9 +31,9 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( make_inference_fn=make_inference_fn, params=(None, policy_params, None), - observation_shapes={"pixels/view_0": obs_shape}, + observation_shapes={"pixels/view_0": obs_shape, 
"state": (10,)}, pixel_obs_keys=("pixels/view_0",), - state_obs_key="", + state_obs_key="state", normalise_channels=True, ) except ImportError as exc: @@ -51,11 +53,17 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = {"pixels/view_0": rng.standard_normal((1, *obs_shape), dtype=np.float32)} + obs = { + "pixels/view_0": rng.standard_normal((1, *obs_shape), dtype=np.float32), + "state": rng.standard_normal((1, 10), dtype=np.float32), + } onnx_action = session.run(["continuous_actions"], obs)[0][0] jax_action = np.asarray( jax_policy( - {"pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"])}, + { + "pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"]), + "state": jax.numpy.asarray(obs["state"]), + }, jax.random.PRNGKey(i), )[0][0] ) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index 6945bf3c3..96b0ef71b 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -39,6 +39,11 @@ _EVAL_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( "eval_domain_randomization", True, "Enable domain randomization for evaluation" ) +_NORMALIZE_OBSERVATIONS = flags.DEFINE_boolean( + "normalize_observations", + False, + "Whether to normalize non-pixel observations (state). 
Mirrors franka_online by default.", +) _RENDER_WIDTH = flags.DEFINE_integer("render_width", 64, "Render width") _RENDER_HEIGHT = flags.DEFINE_integer("render_height", 64, "Render height") _USE_RASTERIZER = flags.DEFINE_boolean( @@ -120,6 +125,8 @@ def main(argv): value_hidden_layer_sizes=[256, 256], activation=linen.relu, normalise_channels=True, + policy_obs_key="state", + value_obs_key="state", ) ppo_params = manipulation_params.brax_vision_ppo_config(_BASE_CONFIG_ENV_NAME) @@ -128,8 +135,13 @@ def main(argv): ppo_params.num_eval_envs = num_envs ppo_params.episode_length = episode_length ppo_params.action_repeat = 1 + ppo_params.normalize_observations = _NORMALIZE_OBSERVATIONS.value del ppo_params.network_factory ppo_params.network_factory = network_factory + print( + "PPO normalize_observations=" + f"{ppo_params.normalize_observations} (franka_online uses false)" + ) times = [datetime.now()] diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index 64ca52993..ba91c992f 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -50,6 +50,12 @@ def _extract_policy_params(params: Any) -> Mapping[str, Any]: return policy_params +def _extract_normalizer_params(params: Any) -> Any: + if isinstance(params, (tuple, list)) and len(params) >= 1: + return params[0] + return None + + def _sorted_hidden_keys(mlp_params: Mapping[str, Any]) -> list[str]: hidden_keys = [k for k in mlp_params.keys() if k.startswith("hidden_")] hidden_keys.sort(key=lambda x: int(x.split("_")[-1])) @@ -107,6 +113,52 @@ def _resolve_tf_activation(cfg: Any): return tf.nn.relu +def _resolve_state_obs_key(cfg: Any, fallback: str = "state") -> str: + for path in ( + ("agent", "policy_obs_key"), + ("agent", "state_obs_key"), + ("agent", "network_factory", "policy_obs_key"), + ): + key = _get_path(cfg, *path) + if isinstance(key, str) and key: + return key + return fallback + + +def 
_resolve_normalize_observations(cfg: Any, default: bool = False) -> bool: + value = _get_path(cfg, "agent", "normalize_observations") + if value is None: + return default + return bool(value) + + +def _infer_state_obs_dim(policy_params: Mapping[str, Any]) -> int: + cnn_names = [k for k in policy_params.keys() if k.startswith("CNN_")] + if not cnn_names: + return 0 + input_dim = int(np.asarray(policy_params["MLP_0"]["hidden_0"]["kernel"]).shape[0]) # type: ignore + cnn_feat_dim = len(cnn_names) * 64 + return max(0, input_dim - cnn_feat_dim) + + +def _extract_state_mean_std( + normalizer_params: Any, state_obs_key: str +) -> tuple[np.ndarray | None, np.ndarray | None]: + if normalizer_params is None or not state_obs_key: + return None, None + mean = _get_path(normalizer_params, "mean") + std = _get_path(normalizer_params, "std") + if mean is None or std is None: + return None, None + if isinstance(mean, Mapping): + mean = mean.get(state_obs_key) + if isinstance(std, Mapping): + std = std.get(state_obs_key) + if mean is None or std is None: + return None, None + return np.asarray(mean, dtype=np.float32), np.asarray(std, dtype=np.float32) + + if tf is not None and layers is not None: class VisionPPOPolicy(tf.keras.Model): @@ -120,6 +172,8 @@ def __init__( activation: Any, normalise_channels: bool = True, state_obs_key: str = "", + state_mean: np.ndarray | None = None, + state_std: np.ndarray | None = None, **kwargs, ): super().__init__(**kwargs) @@ -127,6 +181,16 @@ def __init__( self.pixel_obs_keys = tuple(pixel_obs_keys) self.normalise_channels = normalise_channels self.state_obs_key = state_obs_key + self.state_mean = ( + tf.constant(state_mean, dtype=tf.float32) + if state_mean is not None + else None + ) + self.state_std = ( + tf.constant(state_std, dtype=tf.float32) + if state_std is not None + else None + ) self.cnn_blocks = [] for i, _ in enumerate(self.pixel_obs_keys): @@ -191,7 +255,10 @@ def call(self, obs: Mapping[str, tf.Tensor]) -> tf.Tensor: 
cnn_outs.append(hidden) if self.state_obs_key: - cnn_outs.append(obs[self.state_obs_key]) + state_obs = obs[self.state_obs_key] + if self.state_mean is not None and self.state_std is not None: + state_obs = (state_obs - self.state_mean) / self.state_std + cnn_outs.append(state_obs) logits = self.mlp_block(tf.concat(cnn_outs, axis=-1)) loc, _ = tf.split(logits, 2, axis=-1) @@ -236,15 +303,20 @@ def convert_policy_to_onnx( pixel_obs_keys: Sequence[str] | None = None, state_obs_key: str = "", normalise_channels: bool = True, + normalize_state_observations: bool | None = None, ): """Converts PPO vision policy params to ONNX via TensorFlow.""" if tf is None or tf2onnx is None or layers is None: raise ImportError("TensorFlow/tf2onnx is required for ONNX export.") + normalizer_params = _extract_normalizer_params(params) policy_params = _extract_policy_params(params) action_size = _infer_action_size(policy_params) hidden_layers = _resolve_hidden_layers(cfg, _infer_hidden_layers(policy_params)) activation = _resolve_tf_activation(cfg) + state_obs_dim = _infer_state_obs_dim(policy_params) + if not state_obs_key and state_obs_dim > 0: + state_obs_key = _resolve_state_obs_key(cfg) cnn_names = sorted( [k for k in policy_params.keys() if k.startswith("CNN_")], @@ -268,6 +340,31 @@ def convert_policy_to_onnx( np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore ) observation_shapes[key] = (h, w, channels) + else: + observation_shapes = dict(observation_shapes) + + if normalize_state_observations is None: + normalize_state_observations = _resolve_normalize_observations( + cfg, default=False + ) + + if state_obs_key and state_obs_key not in observation_shapes: + state_mean, _ = _extract_state_mean_std(normalizer_params, state_obs_key) + if state_mean is not None: + observation_shapes[state_obs_key] = tuple(int(x) for x in state_mean.shape) # type: ignore + elif state_obs_dim > 0: + observation_shapes[state_obs_key] = (state_obs_dim,) + else: + raise 
ValueError( + f"Missing observation shape for '{state_obs_key}' and state dim could not be inferred." + ) + + if normalize_state_observations: + state_mean, state_std = _extract_state_mean_std( + normalizer_params, state_obs_key + ) + else: + state_mean, state_std = None, None tf_policy_network = VisionPPOPolicy( action_size=action_size, @@ -276,6 +373,8 @@ def convert_policy_to_onnx( activation=activation, normalise_channels=normalise_channels, state_obs_key=state_obs_key, + state_mean=state_mean, + state_std=state_std, ) dummy_obs = { @@ -324,13 +423,29 @@ def make_franka_policy( obs_keys.append(key) obs_shapes[key] = (h, w, channels) + state_obs_dim = _infer_state_obs_dim(policy_params) + state_obs_key = "" + normalize_state_observations = _resolve_normalize_observations(cfg, default=False) + if state_obs_dim > 0: + state_obs_key = _resolve_state_obs_key(cfg) + if normalize_state_observations: + normalizer_params = _extract_normalizer_params(params) + state_mean, _ = _extract_state_mean_std(normalizer_params, state_obs_key) + if state_mean is not None: + obs_shapes[state_obs_key] = tuple(int(x) for x in state_mean.shape) # type: ignore + else: + obs_shapes[state_obs_key] = (state_obs_dim,) # type: ignore + else: + obs_shapes[state_obs_key] = (state_obs_dim,) # type: ignore + model_proto = convert_policy_to_onnx( make_policy_fn, params, cfg=cfg, observation_shapes=obs_shapes, pixel_obs_keys=obs_keys, - state_obs_key="", + state_obs_key=state_obs_key, normalise_channels=True, + normalize_state_observations=normalize_state_observations, ) return model_proto.SerializeToString() From ece378231eacc7c50e0137852440ba05b8bed97f Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:02:17 +0100 Subject: [PATCH 08/18] Fix conversion testing script --- ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py 
b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index ba91c992f..66e5ade00 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -6,6 +6,7 @@ import jax import numpy as np +from brax.training.acme import running_statistics try: import tensorflow as tf @@ -384,7 +385,26 @@ def convert_policy_to_onnx( tf_policy_network(dummy_obs).numpy() transfer_weights(policy_params, tf_policy_network) - inference_fn = make_inference_fn(params, deterministic=True) + params_for_inference = params + if ( + isinstance(params, (tuple, list)) + and len(params) >= 2 + and params[0] is None + and state_obs_key + and state_obs_key in observation_shapes + ): + # Brax vision apply() calls normalizer_select(...) whenever state_obs_key is set. + # If caller passes params[0]=None (common in lightweight export tests), build a + # neutral running-stats container so the sanity-check inference can run. + dummy_state = { + state_obs_key: jax.numpy.zeros( + observation_shapes[state_obs_key], dtype=jax.numpy.float32 + ) + } + normalizer_params = running_statistics.init_state(dummy_state) + params_for_inference = (normalizer_params, *params[1:]) + + inference_fn = make_inference_fn(params_for_inference, deterministic=True) jax_obs = {k: jax.numpy.asarray(v) for k, v in dummy_obs.items()} jax_pred = np.asarray(inference_fn(jax_obs, jax.random.PRNGKey(0))[0][0]) tf_pred = np.asarray(tf_policy_network(dummy_obs).numpy()[0]) From 4c3685eb160813138603e58e3b6c627c8861d7a0 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:10:09 +0100 Subject: [PATCH 09/18] Revert "Fix conversion testing script" This reverts commit ece378231eacc7c50e0137852440ba05b8bed97f. 
--- ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index 66e5ade00..ba91c992f 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -6,7 +6,6 @@ import jax import numpy as np -from brax.training.acme import running_statistics try: import tensorflow as tf @@ -385,26 +384,7 @@ def convert_policy_to_onnx( tf_policy_network(dummy_obs).numpy() transfer_weights(policy_params, tf_policy_network) - params_for_inference = params - if ( - isinstance(params, (tuple, list)) - and len(params) >= 2 - and params[0] is None - and state_obs_key - and state_obs_key in observation_shapes - ): - # Brax vision apply() calls normalizer_select(...) whenever state_obs_key is set. - # If caller passes params[0]=None (common in lightweight export tests), build a - # neutral running-stats container so the sanity-check inference can run. - dummy_state = { - state_obs_key: jax.numpy.zeros( - observation_shapes[state_obs_key], dtype=jax.numpy.float32 - ) - } - normalizer_params = running_statistics.init_state(dummy_state) - params_for_inference = (normalizer_params, *params[1:]) - - inference_fn = make_inference_fn(params_for_inference, deterministic=True) + inference_fn = make_inference_fn(params, deterministic=True) jax_obs = {k: jax.numpy.asarray(v) for k, v in dummy_obs.items()} jax_pred = np.asarray(inference_fn(jax_obs, jax.random.PRNGKey(0))[0][0]) tf_pred = np.asarray(tf_policy_network(dummy_obs).numpy()[0]) From 2b215f0f308c613a2095d8c9b6fc6ad379d3904a Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:10:11 +0100 Subject: [PATCH 10/18] Revert "Make sure to use preprioceptive and no normalization" This reverts commit 542630f15849cac919479fac608f673dc8fa8d93. 
--- franka_experiments/ppo_vision_flax_to_onnx.py | 16 +-- .../train_pick_cartesian_vision_ppo.py | 12 -- ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 119 +----------------- 3 files changed, 6 insertions(+), 141 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 4e5fae0ca..d276f61c2 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -22,8 +22,6 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): value_hidden_layer_sizes=(256, 256), activation=linen.relu, normalise_channels=True, - policy_obs_key="state", - value_obs_key="state", ) policy_params = ppo_network.policy_network.init(jax.random.PRNGKey(0)) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) @@ -31,9 +29,9 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( make_inference_fn=make_inference_fn, params=(None, policy_params, None), - observation_shapes={"pixels/view_0": obs_shape, "state": (10,)}, + observation_shapes={"pixels/view_0": obs_shape}, pixel_obs_keys=("pixels/view_0",), - state_obs_key="state", + state_obs_key="", normalise_channels=True, ) except ImportError as exc: @@ -53,17 +51,11 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = { - "pixels/view_0": rng.standard_normal((1, *obs_shape), dtype=np.float32), - "state": rng.standard_normal((1, 10), dtype=np.float32), - } + obs = {"pixels/view_0": rng.standard_normal((1, *obs_shape), dtype=np.float32)} onnx_action = session.run(["continuous_actions"], obs)[0][0] jax_action = np.asarray( jax_policy( - { - "pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"]), - "state": jax.numpy.asarray(obs["state"]), - }, + {"pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"])}, jax.random.PRNGKey(i), 
)[0][0] ) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index 96b0ef71b..6945bf3c3 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -39,11 +39,6 @@ _EVAL_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( "eval_domain_randomization", True, "Enable domain randomization for evaluation" ) -_NORMALIZE_OBSERVATIONS = flags.DEFINE_boolean( - "normalize_observations", - False, - "Whether to normalize non-pixel observations (state). Mirrors franka_online by default.", -) _RENDER_WIDTH = flags.DEFINE_integer("render_width", 64, "Render width") _RENDER_HEIGHT = flags.DEFINE_integer("render_height", 64, "Render height") _USE_RASTERIZER = flags.DEFINE_boolean( @@ -125,8 +120,6 @@ def main(argv): value_hidden_layer_sizes=[256, 256], activation=linen.relu, normalise_channels=True, - policy_obs_key="state", - value_obs_key="state", ) ppo_params = manipulation_params.brax_vision_ppo_config(_BASE_CONFIG_ENV_NAME) @@ -135,13 +128,8 @@ def main(argv): ppo_params.num_eval_envs = num_envs ppo_params.episode_length = episode_length ppo_params.action_repeat = 1 - ppo_params.normalize_observations = _NORMALIZE_OBSERVATIONS.value del ppo_params.network_factory ppo_params.network_factory = network_factory - print( - "PPO normalize_observations=" - f"{ppo_params.normalize_observations} (franka_online uses false)" - ) times = [datetime.now()] diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index ba91c992f..64ca52993 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -50,12 +50,6 @@ def _extract_policy_params(params: Any) -> Mapping[str, Any]: return policy_params -def _extract_normalizer_params(params: Any) -> Any: - if isinstance(params, (tuple, list)) and len(params) >= 1: - return params[0] - return None - - def 
_sorted_hidden_keys(mlp_params: Mapping[str, Any]) -> list[str]: hidden_keys = [k for k in mlp_params.keys() if k.startswith("hidden_")] hidden_keys.sort(key=lambda x: int(x.split("_")[-1])) @@ -113,52 +107,6 @@ def _resolve_tf_activation(cfg: Any): return tf.nn.relu -def _resolve_state_obs_key(cfg: Any, fallback: str = "state") -> str: - for path in ( - ("agent", "policy_obs_key"), - ("agent", "state_obs_key"), - ("agent", "network_factory", "policy_obs_key"), - ): - key = _get_path(cfg, *path) - if isinstance(key, str) and key: - return key - return fallback - - -def _resolve_normalize_observations(cfg: Any, default: bool = False) -> bool: - value = _get_path(cfg, "agent", "normalize_observations") - if value is None: - return default - return bool(value) - - -def _infer_state_obs_dim(policy_params: Mapping[str, Any]) -> int: - cnn_names = [k for k in policy_params.keys() if k.startswith("CNN_")] - if not cnn_names: - return 0 - input_dim = int(np.asarray(policy_params["MLP_0"]["hidden_0"]["kernel"]).shape[0]) # type: ignore - cnn_feat_dim = len(cnn_names) * 64 - return max(0, input_dim - cnn_feat_dim) - - -def _extract_state_mean_std( - normalizer_params: Any, state_obs_key: str -) -> tuple[np.ndarray | None, np.ndarray | None]: - if normalizer_params is None or not state_obs_key: - return None, None - mean = _get_path(normalizer_params, "mean") - std = _get_path(normalizer_params, "std") - if mean is None or std is None: - return None, None - if isinstance(mean, Mapping): - mean = mean.get(state_obs_key) - if isinstance(std, Mapping): - std = std.get(state_obs_key) - if mean is None or std is None: - return None, None - return np.asarray(mean, dtype=np.float32), np.asarray(std, dtype=np.float32) - - if tf is not None and layers is not None: class VisionPPOPolicy(tf.keras.Model): @@ -172,8 +120,6 @@ def __init__( activation: Any, normalise_channels: bool = True, state_obs_key: str = "", - state_mean: np.ndarray | None = None, - state_std: np.ndarray | None = 
None, **kwargs, ): super().__init__(**kwargs) @@ -181,16 +127,6 @@ def __init__( self.pixel_obs_keys = tuple(pixel_obs_keys) self.normalise_channels = normalise_channels self.state_obs_key = state_obs_key - self.state_mean = ( - tf.constant(state_mean, dtype=tf.float32) - if state_mean is not None - else None - ) - self.state_std = ( - tf.constant(state_std, dtype=tf.float32) - if state_std is not None - else None - ) self.cnn_blocks = [] for i, _ in enumerate(self.pixel_obs_keys): @@ -255,10 +191,7 @@ def call(self, obs: Mapping[str, tf.Tensor]) -> tf.Tensor: cnn_outs.append(hidden) if self.state_obs_key: - state_obs = obs[self.state_obs_key] - if self.state_mean is not None and self.state_std is not None: - state_obs = (state_obs - self.state_mean) / self.state_std - cnn_outs.append(state_obs) + cnn_outs.append(obs[self.state_obs_key]) logits = self.mlp_block(tf.concat(cnn_outs, axis=-1)) loc, _ = tf.split(logits, 2, axis=-1) @@ -303,20 +236,15 @@ def convert_policy_to_onnx( pixel_obs_keys: Sequence[str] | None = None, state_obs_key: str = "", normalise_channels: bool = True, - normalize_state_observations: bool | None = None, ): """Converts PPO vision policy params to ONNX via TensorFlow.""" if tf is None or tf2onnx is None or layers is None: raise ImportError("TensorFlow/tf2onnx is required for ONNX export.") - normalizer_params = _extract_normalizer_params(params) policy_params = _extract_policy_params(params) action_size = _infer_action_size(policy_params) hidden_layers = _resolve_hidden_layers(cfg, _infer_hidden_layers(policy_params)) activation = _resolve_tf_activation(cfg) - state_obs_dim = _infer_state_obs_dim(policy_params) - if not state_obs_key and state_obs_dim > 0: - state_obs_key = _resolve_state_obs_key(cfg) cnn_names = sorted( [k for k in policy_params.keys() if k.startswith("CNN_")], @@ -340,31 +268,6 @@ def convert_policy_to_onnx( np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore ) observation_shapes[key] = (h, w, 
channels) - else: - observation_shapes = dict(observation_shapes) - - if normalize_state_observations is None: - normalize_state_observations = _resolve_normalize_observations( - cfg, default=False - ) - - if state_obs_key and state_obs_key not in observation_shapes: - state_mean, _ = _extract_state_mean_std(normalizer_params, state_obs_key) - if state_mean is not None: - observation_shapes[state_obs_key] = tuple(int(x) for x in state_mean.shape) # type: ignore - elif state_obs_dim > 0: - observation_shapes[state_obs_key] = (state_obs_dim,) - else: - raise ValueError( - f"Missing observation shape for '{state_obs_key}' and state dim could not be inferred." - ) - - if normalize_state_observations: - state_mean, state_std = _extract_state_mean_std( - normalizer_params, state_obs_key - ) - else: - state_mean, state_std = None, None tf_policy_network = VisionPPOPolicy( action_size=action_size, @@ -373,8 +276,6 @@ def convert_policy_to_onnx( activation=activation, normalise_channels=normalise_channels, state_obs_key=state_obs_key, - state_mean=state_mean, - state_std=state_std, ) dummy_obs = { @@ -423,29 +324,13 @@ def make_franka_policy( obs_keys.append(key) obs_shapes[key] = (h, w, channels) - state_obs_dim = _infer_state_obs_dim(policy_params) - state_obs_key = "" - normalize_state_observations = _resolve_normalize_observations(cfg, default=False) - if state_obs_dim > 0: - state_obs_key = _resolve_state_obs_key(cfg) - if normalize_state_observations: - normalizer_params = _extract_normalizer_params(params) - state_mean, _ = _extract_state_mean_std(normalizer_params, state_obs_key) - if state_mean is not None: - obs_shapes[state_obs_key] = tuple(int(x) for x in state_mean.shape) # type: ignore - else: - obs_shapes[state_obs_key] = (state_obs_dim,) # type: ignore - else: - obs_shapes[state_obs_key] = (state_obs_dim,) # type: ignore - model_proto = convert_policy_to_onnx( make_policy_fn, params, cfg=cfg, observation_shapes=obs_shapes, pixel_obs_keys=obs_keys, - 
state_obs_key=state_obs_key, + state_obs_key="", normalise_channels=True, - normalize_state_observations=normalize_state_observations, ) return model_proto.SerializeToString() From bb54722d0d004e4857687dbf7f61a66f172b5fae Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:14:25 +0100 Subject: [PATCH 11/18] Save checkpoint and fix conversion errors --- franka_experiments/ppo_vision_flax_to_onnx.py | 8 ++++++++ .../train_pick_cartesian_vision_ppo.py | 18 ++++++++++++++++++ ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 3 +++ 3 files changed, 29 insertions(+) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index d276f61c2..ac34d5509 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -12,6 +12,13 @@ ort = None +def _make_ort_session_options(): + opts = ort.SessionOptions() + opts.intra_op_num_threads = 1 + opts.inter_op_num_threads = 1 + return opts + + def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): obs_shape = (64, 64, 1) action_size = 3 @@ -44,6 +51,7 @@ def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): session = ort.InferenceSession( model_proto.SerializeToString(), + sess_options=_make_ort_session_options(), providers=["CPUExecutionProvider"], ) jax_policy = make_inference_fn((None, policy_params, None), deterministic=True) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index 6945bf3c3..0c8f0db5f 100644 --- a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -11,6 +11,7 @@ import functools from datetime import datetime +from pathlib import Path from typing import Any import brax.training.agents.ppo.train as brax_ppo_train @@ -33,6 +34,11 @@ "num_timesteps", 7_000_000, "Total environment steps for PPO training" ) _SEED = 
flags.DEFINE_integer("seed", 0, "PRNG seed") +_SAVE_CHECKPOINT_PATH = flags.DEFINE_string( + "save_checkpoint_path", + "franka_experiments/checkpoints/pick_cartesian_vision_ppo", + "Directory to save Brax PPO checkpoints. Set empty string to disable.", +) _TRAIN_DOMAIN_RANDOMIZATION = flags.DEFINE_boolean( "train_domain_randomization", True, "Enable domain randomization for training" ) @@ -111,6 +117,17 @@ def main(argv): del argv print(f"Using Brax PPO trainer from: {brax_ppo_train.__file__}") + save_checkpoint_path = None + if _SAVE_CHECKPOINT_PATH.value: + save_checkpoint_path = str( + Path(_SAVE_CHECKPOINT_PATH.value).expanduser().resolve() + ) + Path(save_checkpoint_path).mkdir(parents=True, exist_ok=True) + print( + "Checkpoint saving: " + f"{save_checkpoint_path if save_checkpoint_path else 'disabled'}" + ) + train_env, episode_length = _build_env_and_cfg() num_envs = _NUM_ENVS.value @@ -149,6 +166,7 @@ def progress(num_steps, metrics): "madrona_backend": True, "progress_fn": progress, "seed": _SEED.value, + "save_checkpoint_path": save_checkpoint_path, } ) train_fn = functools.partial(brax_ppo_train.train, **train_kwargs) diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index 64ca52993..7acb691d4 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -136,6 +136,7 @@ def __init__( filters=32, kernel_size=(8, 8), strides=(4, 4), + padding="same", activation=tf.nn.relu, use_bias=False, name="Conv_0", @@ -146,6 +147,7 @@ def __init__( filters=64, kernel_size=(4, 4), strides=(2, 2), + padding="same", activation=tf.nn.relu, use_bias=False, name="Conv_1", @@ -156,6 +158,7 @@ def __init__( filters=64, kernel_size=(3, 3), strides=(1, 1), + padding="same", activation=tf.nn.relu, use_bias=False, name="Conv_2", From 287abac1a46cf5a3de4abf6c5b6915315cbd4e18 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:20:23 +0100 Subject: [PATCH 12/18] Convert from 
checkpint --- franka_experiments/ppo_vision_flax_to_onnx.py | 189 ++++++++++++++---- 1 file changed, 153 insertions(+), 36 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index ac34d5509..4cbfe1bee 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Mapping + import jax import numpy as np +from brax.training.acme import running_statistics +from brax.training.agents.ppo import checkpoint as ppo_checkpoint from brax.training.agents.ppo import networks as ppo_networks from brax.training.agents.ppo import networks_vision as ppo_networks_vision -from flax import linen from ss2r.algorithms.ppo import franka_ppo_to_onnx @@ -12,6 +19,43 @@ ort = None +def _resolve_checkpoint_dir(path_str: str) -> Path: + path = Path(path_str).expanduser().resolve() + if not path.exists(): + raise FileNotFoundError(f"Checkpoint path does not exist: {path}") + + config_name = "ppo_network_config.json" + if (path / config_name).exists(): + return path + + candidates = [ + p for p in path.iterdir() if p.is_dir() and (p / config_name).exists() + ] + if not candidates: + raise FileNotFoundError( + f"No checkpoint step with {config_name} found under: {path}" + ) + + numeric = [p for p in candidates if p.name.isdigit()] + if numeric: + return max(numeric, key=lambda p: int(p.name)) + return sorted(candidates)[-1] + + +def _as_obs_shapes( + observation_size: Mapping[str, object] +) -> dict[str, tuple[int, ...]]: + shapes: dict[str, tuple[int, ...]] = {} + for k, shape in observation_size.items(): + if isinstance(shape, tuple): + shapes[k] = tuple(int(x) for x in shape) + elif isinstance(shape, list): + shapes[k] = tuple(int(x) for x in shape) + else: + raise TypeError(f"Unsupported observation shape for key {k}: {shape}") + return shapes + + def 
_make_ort_session_options(): opts = ort.SessionOptions() opts.intra_op_num_threads = 1 @@ -19,66 +63,139 @@ def _make_ort_session_options(): return opts -def test_policy_to_onnx_export(num_tests: int = 5, atol: float = 1e-4): - obs_shape = (64, 64, 1) - action_size = 3 +def _build_random_obs( + rng: np.random.Generator, obs_shapes: Mapping[str, tuple[int, ...]] +) -> dict[str, np.ndarray]: + return { + k: rng.standard_normal((1, *shape), dtype=np.float32) + for k, shape in obs_shapes.items() + } + + +def convert_checkpoint_to_onnx( + checkpoint_path: str, + output_path: str | None = None, + num_tests: int = 0, + atol: float = 1e-4, +) -> Path: + ckpt_dir = _resolve_checkpoint_dir(checkpoint_path) + print(f"Using PPO checkpoint: {ckpt_dir}") + + params = ppo_checkpoint.load(ckpt_dir) + config = ppo_checkpoint.load_config(ckpt_dir) + config_dict = config.to_dict() + + observation_size = config_dict.get("observation_size") + if not isinstance(observation_size, dict): + raise TypeError("Checkpoint observation_size must be a mapping for vision PPO.") + obs_shapes = _as_obs_shapes(observation_size) + pixel_obs_keys = tuple(k for k in obs_shapes if k.startswith("pixels/")) + if not pixel_obs_keys: + raise ValueError( + "No pixel observation keys found in checkpoint observation_size." 
+ ) + + network_factory_kwargs = config_dict.get("network_factory_kwargs", {}) + if not isinstance(network_factory_kwargs, dict): + network_factory_kwargs = {} + + normalize = ( + running_statistics.normalize + if bool(config_dict.get("normalize_observations", False)) + else (lambda x, y: x) + ) ppo_network = ppo_networks_vision.make_ppo_networks_vision( - observation_size={"pixels/view_0": obs_shape, "state": (10,)}, - action_size=action_size, - policy_hidden_layer_sizes=(256, 256), - value_hidden_layer_sizes=(256, 256), - activation=linen.relu, - normalise_channels=True, + observation_size=observation_size, + action_size=int(config_dict["action_size"]), + preprocess_observations_fn=normalize, + **network_factory_kwargs, ) - policy_params = ppo_network.policy_network.init(jax.random.PRNGKey(0)) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) - try: - model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( - make_inference_fn=make_inference_fn, - params=(None, policy_params, None), - observation_shapes={"pixels/view_0": obs_shape}, - pixel_obs_keys=("pixels/view_0",), - state_obs_key="", - normalise_channels=True, - ) - except ImportError as exc: - print(f"Skipping export: {exc}") - return + state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) + model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( + make_inference_fn=make_inference_fn, + params=params, + observation_shapes=obs_shapes, + pixel_obs_keys=pixel_obs_keys, + state_obs_key=state_obs_key, + normalise_channels=True, + ) + + if output_path is None: + output_file = ckpt_dir / "policy.onnx" + else: + output_file = Path(output_path).expanduser().resolve() + output_file.parent.mkdir(parents=True, exist_ok=True) + output_file.write_bytes(model_proto.SerializeToString()) + print(f"Wrote ONNX model: {output_file}") + + if num_tests <= 0: + return output_file if ort is None: - print("Skipping ONNX runtime parity check: onnxruntime is not installed.") - return + 
print("onnxruntime not installed, skipping parity checks.") + return output_file session = ort.InferenceSession( model_proto.SerializeToString(), sess_options=_make_ort_session_options(), providers=["CPUExecutionProvider"], ) - jax_policy = make_inference_fn((None, policy_params, None), deterministic=True) + jax_policy = make_inference_fn(params, deterministic=True) rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = {"pixels/view_0": rng.standard_normal((1, *obs_shape), dtype=np.float32)} + obs = _build_random_obs(rng, obs_shapes) onnx_action = session.run(["continuous_actions"], obs)[0][0] - jax_action = np.asarray( - jax_policy( - {"pixels/view_0": jax.numpy.asarray(obs["pixels/view_0"])}, - jax.random.PRNGKey(i), - )[0][0] - ) + jax_obs = {k: jax.numpy.asarray(v) for k, v in obs.items()} + jax_action = np.asarray(jax_policy(jax_obs, jax.random.PRNGKey(i))[0][0]) err = float(np.max(np.abs(onnx_action - jax_action))) max_err = max(max_err, err) print(f"sample={i} max_abs_err={err:.6e}") - print(f" jax : {jax_action}") - print(f" onnx: {onnx_action}") print(f"overall max_abs_err={max_err:.6e} (atol={atol:.1e})") if max_err > atol: raise AssertionError( f"ONNX/JAX mismatch: max_abs_err={max_err:.6e} > atol={atol}" ) + return output_file + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert PPO vision policy checkpoint to ONNX." 
+ ) + parser.add_argument( + "--checkpoint_path", + required=True, + help="Checkpoint step directory or parent directory containing step subdirs.", + ) + parser.add_argument( + "--output_path", + default=None, + help="Output ONNX file path (default: /policy.onnx).", + ) + parser.add_argument( + "--num_tests", + type=int, + default=0, + help="Number of ONNX-vs-JAX parity tests to run.", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-4, + help="Absolute tolerance for parity checks.", + ) + return parser.parse_args() if __name__ == "__main__": - test_policy_to_onnx_export() + args = _parse_args() + convert_checkpoint_to_onnx( + checkpoint_path=args.checkpoint_path, + output_path=args.output_path, + num_tests=args.num_tests, + atol=args.atol, + ) From b46fa8ca3ef59cf0a565cec2bbceed13ce3bd430 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:23:44 +0100 Subject: [PATCH 13/18] Ignore state --- franka_experiments/ppo_vision_flax_to_onnx.py | 12 ++++++++++-- ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 15 +++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 4cbfe1bee..5643f11ac 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -113,10 +113,18 @@ def convert_checkpoint_to_onnx( make_inference_fn = ppo_networks.make_inference_fn(ppo_network) state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) + model_input_shapes = {k: obs_shapes[k] for k in pixel_obs_keys} + if state_obs_key: + if state_obs_key not in obs_shapes: + raise ValueError( + f"state_obs_key='{state_obs_key}' missing from checkpoint observation_size" + ) + model_input_shapes[state_obs_key] = obs_shapes[state_obs_key] + model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( make_inference_fn=make_inference_fn, params=params, - observation_shapes=obs_shapes, + 
observation_shapes=model_input_shapes, pixel_obs_keys=pixel_obs_keys, state_obs_key=state_obs_key, normalise_channels=True, @@ -146,7 +154,7 @@ def convert_checkpoint_to_onnx( rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = _build_random_obs(rng, obs_shapes) + obs = _build_random_obs(rng, model_input_shapes) onnx_action = session.run(["continuous_actions"], obs)[0][0] jax_obs = {k: jax.numpy.asarray(v) for k, v in obs.items()} jax_action = np.asarray(jax_policy(jax_obs, jax.random.PRNGKey(i))[0][0]) diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index 7acb691d4..3c8a6762a 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -272,6 +272,17 @@ def convert_policy_to_onnx( ) observation_shapes[key] = (h, w, channels) + used_observation_shapes = {k: tuple(observation_shapes[k]) for k in pixel_obs_keys} + if state_obs_key: + if state_obs_key not in observation_shapes: + raise ValueError( + f"state_obs_key='{state_obs_key}' not found in observation_shapes keys=" + f"{list(observation_shapes.keys())}" + ) + used_observation_shapes[state_obs_key] = tuple( + observation_shapes[state_obs_key] + ) + tf_policy_network = VisionPPOPolicy( action_size=action_size, pixel_obs_keys=pixel_obs_keys, @@ -283,7 +294,7 @@ def convert_policy_to_onnx( dummy_obs = { k: np.ones((1, *shape), dtype=np.float32) - for k, shape in observation_shapes.items() + for k, shape in used_observation_shapes.items() } tf_policy_network(dummy_obs).numpy() transfer_weights(policy_params, tf_policy_network) @@ -301,7 +312,7 @@ def convert_policy_to_onnx( input_signature=[ { k: tf.TensorSpec([1, *shape], tf.float32, name=k) - for k, shape in observation_shapes.items() + for k, shape in used_observation_shapes.items() } ], opset=11, From aa896ffe6d4fa54a29f910337c7a96161b366946 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:48:11 +0100 Subject: [PATCH 14/18] 
Test --- franka_experiments/ppo_vision_flax_to_onnx.py | 15 +++++++- .../train_pick_cartesian_vision_ppo.py | 34 +++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 5643f11ac..65778863b 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -10,6 +10,7 @@ from brax.training.agents.ppo import checkpoint as ppo_checkpoint from brax.training.agents.ppo import networks as ppo_networks from brax.training.agents.ppo import networks_vision as ppo_networks_vision +from flax import linen from ss2r.algorithms.ppo import franka_ppo_to_onnx @@ -98,17 +99,29 @@ def convert_checkpoint_to_onnx( network_factory_kwargs = config_dict.get("network_factory_kwargs", {}) if not isinstance(network_factory_kwargs, dict): network_factory_kwargs = {} + network_factory_kwargs = dict(network_factory_kwargs) + activation_name = network_factory_kwargs.pop("activation_name", "") + activation = None + if activation_name: + activation = getattr(linen, activation_name, None) + if activation is None: + raise ValueError( + f"Unsupported activation_name in checkpoint: {activation_name!r}" + ) normalize = ( running_statistics.normalize if bool(config_dict.get("normalize_observations", False)) else (lambda x, y: x) ) + ppo_network_kwargs = dict(network_factory_kwargs) + if activation is not None: + ppo_network_kwargs["activation"] = activation ppo_network = ppo_networks_vision.make_ppo_networks_vision( observation_size=observation_size, action_size=int(config_dict["action_size"]), preprocess_observations_fn=normalize, - **network_factory_kwargs, + **ppo_network_kwargs, ) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) diff --git a/franka_experiments/train_pick_cartesian_vision_ppo.py b/franka_experiments/train_pick_cartesian_vision_ppo.py index 0c8f0db5f..a84d73c1d 100644 --- 
a/franka_experiments/train_pick_cartesian_vision_ppo.py +++ b/franka_experiments/train_pick_cartesian_vision_ppo.py @@ -52,6 +52,36 @@ ) +def _make_ppo_networks_vision_ckpt_compatible( + observation_size, + action_size, + preprocess_observations_fn, + policy_hidden_layer_sizes=(256, 256), + value_hidden_layer_sizes=(256, 256), + normalise_channels=False, + policy_obs_key="", + value_obs_key="", + activation_name: str = "relu", +): + # FIXME: Brax checkpoint.network_config rejects non-default `activation` kwargs. + # Route activation through `activation_name` so checkpoints can be saved while + # still building the ReLU network used in this experiment. + activation = getattr(linen, activation_name, None) + if activation is None: + raise ValueError(f"Unsupported activation_name={activation_name!r}") + return ppo_networks_vision.make_ppo_networks_vision( + observation_size=observation_size, + action_size=action_size, + preprocess_observations_fn=preprocess_observations_fn, + policy_hidden_layer_sizes=policy_hidden_layer_sizes, + value_hidden_layer_sizes=value_hidden_layer_sizes, + activation=activation, + normalise_channels=normalise_channels, + policy_obs_key=policy_obs_key, + value_obs_key=value_obs_key, + ) + + def _set_nested(cfg: config_dict.ConfigDict, key: str, value: Any) -> None: keys = key.split(".") node = cfg @@ -132,10 +162,10 @@ def main(argv): num_envs = _NUM_ENVS.value network_factory = functools.partial( - ppo_networks_vision.make_ppo_networks_vision, + _make_ppo_networks_vision_ckpt_compatible, policy_hidden_layer_sizes=[256, 256], value_hidden_layer_sizes=[256, 256], - activation=linen.relu, + activation_name="relu", normalise_channels=True, ) From d556024be2361e1b4a2d717d55e21133347b8618 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 10:58:15 +0100 Subject: [PATCH 15/18] Fix shape loading --- franka_experiments/ppo_vision_flax_to_onnx.py | 55 +++++++++++++++++-- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git 
a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 65778863b..10afd5593 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -52,6 +52,16 @@ def _as_obs_shapes( shapes[k] = tuple(int(x) for x in shape) elif isinstance(shape, list): shapes[k] = tuple(int(x) for x in shape) + elif isinstance(shape, Mapping) and "shape" in shape: + raw_shape = shape["shape"] + if isinstance(raw_shape, tuple): + shapes[k] = tuple(int(x) for x in raw_shape) + elif isinstance(raw_shape, list): + shapes[k] = tuple(int(x) for x in raw_shape) + else: + raise TypeError( + f"Unsupported serialized shape for key {k}: {raw_shape}" + ) else: raise TypeError(f"Unsupported observation shape for key {k}: {shape}") return shapes @@ -90,7 +100,7 @@ def convert_checkpoint_to_onnx( if not isinstance(observation_size, dict): raise TypeError("Checkpoint observation_size must be a mapping for vision PPO.") obs_shapes = _as_obs_shapes(observation_size) - pixel_obs_keys = tuple(k for k in obs_shapes if k.startswith("pixels/")) + pixel_obs_keys = tuple(k for k in observation_size if k.startswith("pixels/")) if not pixel_obs_keys: raise ValueError( "No pixel observation keys found in checkpoint observation_size." 
@@ -126,18 +136,37 @@ def convert_checkpoint_to_onnx( make_inference_fn = ppo_networks.make_inference_fn(ppo_network) state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) - model_input_shapes = {k: obs_shapes[k] for k in pixel_obs_keys} + model_input_shapes: dict[str, tuple[int, ...]] = {} + needs_pixel_shape_inference = False + for key in pixel_obs_keys: + shape = obs_shapes.get(key) + if shape is None or len(shape) != 3: + needs_pixel_shape_inference = True + continue + model_input_shapes[key] = shape if state_obs_key: - if state_obs_key not in obs_shapes: + state_shape = obs_shapes.get(state_obs_key) + if state_shape is not None: + model_input_shapes[state_obs_key] = state_shape + elif not needs_pixel_shape_inference: raise ValueError( f"state_obs_key='{state_obs_key}' missing from checkpoint observation_size" ) - model_input_shapes[state_obs_key] = obs_shapes[state_obs_key] + + observation_shapes_for_export: dict[str, tuple[int, ...]] | None + if needs_pixel_shape_inference: + print( + "Checkpoint pixel observation shapes are incomplete; " + "falling back to exporter shape inference." + ) + observation_shapes_for_export = None + else: + observation_shapes_for_export = model_input_shapes model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( make_inference_fn=make_inference_fn, params=params, - observation_shapes=model_input_shapes, + observation_shapes=observation_shapes_for_export, pixel_obs_keys=pixel_obs_keys, state_obs_key=state_obs_key, normalise_channels=True, @@ -164,10 +193,24 @@ def convert_checkpoint_to_onnx( ) jax_policy = make_inference_fn(params, deterministic=True) + if observation_shapes_for_export is None: + model_input_shapes_for_test: dict[str, tuple[int, ...]] = {} + for inp in session.get_inputs(): + raw_shape = list(inp.shape)[1:] + if any((d is None or isinstance(d, str)) for d in raw_shape): + raise ValueError( + f"Cannot infer static shape for ONNX input {inp.name}: {inp.shape}. 
" + "Run without --num_tests or provide a checkpoint with full " + "observation shapes." + ) + model_input_shapes_for_test[inp.name] = tuple(int(d) for d in raw_shape) + else: + model_input_shapes_for_test = model_input_shapes + rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = _build_random_obs(rng, model_input_shapes) + obs = _build_random_obs(rng, model_input_shapes_for_test) onnx_action = session.run(["continuous_actions"], obs)[0][0] jax_obs = {k: jax.numpy.asarray(v) for k, v in obs.items()} jax_action = np.asarray(jax_policy(jax_obs, jax.random.PRNGKey(i))[0][0]) From 66b82c391e6e446226512560e7fd59bcb51d4961 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 11:05:02 +0100 Subject: [PATCH 16/18] Idk --- franka_experiments/ppo_vision_flax_to_onnx.py | 136 +++++++++++++++--- 1 file changed, 117 insertions(+), 19 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 10afd5593..8e7e35c34 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -20,6 +20,41 @@ ort = None +def _extract_policy_params(params): + if isinstance(params, (tuple, list)): + if len(params) < 2: + raise ValueError("Expected params as (normalizer, policy, value) tuple.") + policy_params = params[1] + else: + policy_params = params + if isinstance(policy_params, Mapping) and "params" in policy_params: + policy_params = policy_params["params"] + if not isinstance(policy_params, Mapping): + raise TypeError("Could not extract policy parameter mapping.") + return policy_params + + +def _sorted_cnn_names(policy_params: Mapping[str, object]) -> list[str]: + cnn_names = [k for k in policy_params.keys() if k.startswith("CNN_")] + cnn_names.sort(key=lambda x: int(x.split("_")[-1])) + return cnn_names + + +def _extract_running_stats_shapes(params) -> dict[str, tuple[int, ...]]: + if not isinstance(params, (tuple, list)) or not 
params: + return {} + running_stats = params[0] + mean = getattr(running_stats, "mean", None) + if not isinstance(mean, Mapping): + return {} + + shapes: dict[str, tuple[int, ...]] = {} + for key, value in mean.items(): + arr = np.asarray(value) + shapes[str(key)] = tuple(int(d) for d in arr.shape) + return shapes + + def _resolve_checkpoint_dir(path_str: str) -> Path: path = Path(path_str).expanduser().resolve() if not path.exists(): @@ -46,24 +81,18 @@ def _resolve_checkpoint_dir(path_str: str) -> Path: def _as_obs_shapes( observation_size: Mapping[str, object] ) -> dict[str, tuple[int, ...]]: + def _shape_to_tuple(raw_shape: object, key: str) -> tuple[int, ...]: + if isinstance(raw_shape, tuple): + return tuple(int(x) for x in raw_shape) + if isinstance(raw_shape, list): + return tuple(int(x) for x in raw_shape) + if isinstance(raw_shape, Mapping) and "shape" in raw_shape: + return _shape_to_tuple(raw_shape["shape"], key) + raise TypeError(f"Unsupported serialized shape for key {key}: {raw_shape}") + shapes: dict[str, tuple[int, ...]] = {} for k, shape in observation_size.items(): - if isinstance(shape, tuple): - shapes[k] = tuple(int(x) for x in shape) - elif isinstance(shape, list): - shapes[k] = tuple(int(x) for x in shape) - elif isinstance(shape, Mapping) and "shape" in shape: - raw_shape = shape["shape"] - if isinstance(raw_shape, tuple): - shapes[k] = tuple(int(x) for x in raw_shape) - elif isinstance(raw_shape, list): - shapes[k] = tuple(int(x) for x in raw_shape) - else: - raise TypeError( - f"Unsupported serialized shape for key {k}: {raw_shape}" - ) - else: - raise TypeError(f"Unsupported observation shape for key {k}: {shape}") + shapes[k] = _shape_to_tuple(shape, k) return shapes @@ -83,6 +112,51 @@ def _build_random_obs( } +def _sanitize_observation_size_for_network( + raw_observation_size: Mapping[str, object], + params, + pixel_obs_keys: tuple[str, ...], + state_obs_keys: tuple[str, ...], +) -> dict[str, tuple[int, ...]]: + obs_shapes = 
_as_obs_shapes(raw_observation_size) + running_stats_shapes = _extract_running_stats_shapes(params) + sanitized: dict[str, tuple[int, ...]] = {} + for key, shape in obs_shapes.items(): + if key in pixel_obs_keys: + sanitized[key] = shape + continue + running_shape = running_stats_shapes.get(key) + if running_shape is not None and (shape == (1,) or shape != running_shape): + sanitized[key] = running_shape + else: + sanitized[key] = shape + + for key in state_obs_keys: + if not key: + continue + if key not in sanitized and key in running_stats_shapes: + sanitized[key] = running_stats_shapes[key] + + policy_params = _extract_policy_params(params) + cnn_names = _sorted_cnn_names(policy_params) + + if len(cnn_names) != len(pixel_obs_keys): + raise ValueError( + f"Checkpoint CNN count ({len(cnn_names)}) does not match pixel keys " + f"({len(pixel_obs_keys)}): {pixel_obs_keys}" + ) + + for i, key in enumerate(pixel_obs_keys): + shape = sanitized.get(key) + if shape is not None and len(shape) == 3: + continue + in_channels = int( + np.asarray(policy_params[cnn_names[i]]["Conv_0"]["kernel"]).shape[2] # type: ignore[index] + ) + sanitized[key] = (64, 64, in_channels) + return sanitized + + def convert_checkpoint_to_onnx( checkpoint_path: str, output_path: str | None = None, @@ -99,7 +173,6 @@ def convert_checkpoint_to_onnx( observation_size = config_dict.get("observation_size") if not isinstance(observation_size, dict): raise TypeError("Checkpoint observation_size must be a mapping for vision PPO.") - obs_shapes = _as_obs_shapes(observation_size) pixel_obs_keys = tuple(k for k in observation_size if k.startswith("pixels/")) if not pixel_obs_keys: raise ValueError( @@ -119,23 +192,48 @@ def convert_checkpoint_to_onnx( f"Unsupported activation_name in checkpoint: {activation_name!r}" ) + state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) + value_state_obs_key = str(network_factory_kwargs.get("value_obs_key", "")) + required_obs_keys = set(pixel_obs_keys) + 
if state_obs_key: + required_obs_keys.add(state_obs_key) + if value_state_obs_key: + required_obs_keys.add(value_state_obs_key) + filtered_observation_size = { + key: observation_size[key] + for key in required_obs_keys + if key in observation_size + } + missing_required_keys = sorted(required_obs_keys - set(filtered_observation_size)) + if missing_required_keys: + raise ValueError( + "Checkpoint observation_size is missing required keys: " + + ", ".join(missing_required_keys) + ) + normalize = ( running_statistics.normalize if bool(config_dict.get("normalize_observations", False)) else (lambda x, y: x) ) + observation_size_for_network = _sanitize_observation_size_for_network( + filtered_observation_size, + params, + pixel_obs_keys, + (state_obs_key, value_state_obs_key), + ) ppo_network_kwargs = dict(network_factory_kwargs) if activation is not None: ppo_network_kwargs["activation"] = activation ppo_network = ppo_networks_vision.make_ppo_networks_vision( - observation_size=observation_size, + observation_size=observation_size_for_network, action_size=int(config_dict["action_size"]), preprocess_observations_fn=normalize, **ppo_network_kwargs, ) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) - state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) + obs_shapes = observation_size_for_network model_input_shapes: dict[str, tuple[int, ...]] = {} needs_pixel_shape_inference = False for key in pixel_obs_keys: From 240af39aa5ce980e73f1cf6a83667ec2f5c6b7a8 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 11:46:22 +0100 Subject: [PATCH 17/18] Cleanup and ignore state --- franka_experiments/ppo_vision_flax_to_onnx.py | 300 +++++++----------- ss2r/algorithms/ppo/franka_ppo_to_onnx.py | 51 +-- 2 files changed, 141 insertions(+), 210 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 8e7e35c34..0b23f3303 100644 --- 
a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -20,41 +20,6 @@ ort = None -def _extract_policy_params(params): - if isinstance(params, (tuple, list)): - if len(params) < 2: - raise ValueError("Expected params as (normalizer, policy, value) tuple.") - policy_params = params[1] - else: - policy_params = params - if isinstance(policy_params, Mapping) and "params" in policy_params: - policy_params = policy_params["params"] - if not isinstance(policy_params, Mapping): - raise TypeError("Could not extract policy parameter mapping.") - return policy_params - - -def _sorted_cnn_names(policy_params: Mapping[str, object]) -> list[str]: - cnn_names = [k for k in policy_params.keys() if k.startswith("CNN_")] - cnn_names.sort(key=lambda x: int(x.split("_")[-1])) - return cnn_names - - -def _extract_running_stats_shapes(params) -> dict[str, tuple[int, ...]]: - if not isinstance(params, (tuple, list)) or not params: - return {} - running_stats = params[0] - mean = getattr(running_stats, "mean", None) - if not isinstance(mean, Mapping): - return {} - - shapes: dict[str, tuple[int, ...]] = {} - for key, value in mean.items(): - arr = np.asarray(value) - shapes[str(key)] = tuple(int(d) for d in arr.shape) - return shapes - - def _resolve_checkpoint_dir(path_str: str) -> Path: path = Path(path_str).expanduser().resolve() if not path.exists(): @@ -64,36 +29,87 @@ def _resolve_checkpoint_dir(path_str: str) -> Path: if (path / config_name).exists(): return path - candidates = [ - p for p in path.iterdir() if p.is_dir() and (p / config_name).exists() - ] - if not candidates: + step_dirs = [p for p in path.iterdir() if p.is_dir() and (p / config_name).exists()] + if not step_dirs: raise FileNotFoundError( f"No checkpoint step with {config_name} found under: {path}" ) - numeric = [p for p in candidates if p.name.isdigit()] + numeric = [p for p in step_dirs if p.name.isdigit()] if numeric: return max(numeric, key=lambda p: int(p.name)) 
- return sorted(candidates)[-1] + return sorted(step_dirs)[-1] -def _as_obs_shapes( - observation_size: Mapping[str, object] -) -> dict[str, tuple[int, ...]]: - def _shape_to_tuple(raw_shape: object, key: str) -> tuple[int, ...]: - if isinstance(raw_shape, tuple): - return tuple(int(x) for x in raw_shape) - if isinstance(raw_shape, list): - return tuple(int(x) for x in raw_shape) - if isinstance(raw_shape, Mapping) and "shape" in raw_shape: - return _shape_to_tuple(raw_shape["shape"], key) - raise TypeError(f"Unsupported serialized shape for key {key}: {raw_shape}") +def _shape_tuple(raw: object) -> tuple[int, ...]: + if isinstance(raw, tuple): + return tuple(int(x) for x in raw) + if isinstance(raw, list): + return tuple(int(x) for x in raw) + if isinstance(raw, Mapping) and "shape" in raw: + return _shape_tuple(raw["shape"]) + raise TypeError(f"Unsupported shape payload: {raw}") - shapes: dict[str, tuple[int, ...]] = {} - for k, shape in observation_size.items(): - shapes[k] = _shape_to_tuple(shape, k) - return shapes + +def _extract_policy_params(params) -> Mapping[str, object]: + policy_params = params[1] if isinstance(params, (tuple, list)) else params + if isinstance(policy_params, Mapping) and "params" in policy_params: + policy_params = policy_params["params"] + if not isinstance(policy_params, Mapping): + raise TypeError("Could not extract policy parameter mapping.") + return policy_params + + +def _extract_pixel_obs_shapes( + observation_size_cfg: Mapping[str, object], + policy_params: Mapping[str, object], +) -> tuple[tuple[str, ...], dict[str, tuple[int, ...]]]: + pixel_obs_keys = tuple(k for k in observation_size_cfg if k.startswith("pixels/")) + if not pixel_obs_keys: + raise ValueError( + "No pixel observation keys found in checkpoint observation_size." 
+ ) + + obs_shapes: dict[str, tuple[int, ...]] = {} + for i, key in enumerate(pixel_obs_keys): + try: + shape = _shape_tuple(observation_size_cfg[key]) + except Exception: + shape = () + + if len(shape) == 3: + obs_shapes[key] = shape + continue + + channels = int( + np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore[index] + ) + obs_shapes[key] = (64, 64, channels) + + return pixel_obs_keys, obs_shapes + + +def _extract_state_shape( + observation_size_cfg: Mapping[str, object], + params, +) -> tuple[int, ...]: + if isinstance(params, (tuple, list)) and params: + running_stats = params[0] + mean = getattr(running_stats, "mean", None) + if isinstance(mean, Mapping) and "state" in mean: + shape = tuple(int(d) for d in np.asarray(mean["state"]).shape) # type: ignore + if shape: + return shape + + if "state" in observation_size_cfg: + try: + shape = _shape_tuple(observation_size_cfg["state"]) + if shape: + return shape + except Exception: + pass + + return (1,) def _make_ort_session_options(): @@ -103,8 +119,21 @@ def _make_ort_session_options(): return opts +def _onnx_input_shapes(session: "ort.InferenceSession") -> dict[str, tuple[int, ...]]: + shapes: dict[str, tuple[int, ...]] = {} + for inp in session.get_inputs(): + raw_shape = list(inp.shape)[1:] + if any((d is None or isinstance(d, str)) for d in raw_shape): + raise ValueError( + f"Cannot infer static shape for ONNX input {inp.name}: {inp.shape}." 
+ ) + shapes[inp.name] = tuple(int(d) for d in raw_shape) + return shapes + + def _build_random_obs( - rng: np.random.Generator, obs_shapes: Mapping[str, tuple[int, ...]] + rng: np.random.Generator, + obs_shapes: Mapping[str, tuple[int, ...]], ) -> dict[str, np.ndarray]: return { k: rng.standard_normal((1, *shape), dtype=np.float32) @@ -112,51 +141,6 @@ def _build_random_obs( } -def _sanitize_observation_size_for_network( - raw_observation_size: Mapping[str, object], - params, - pixel_obs_keys: tuple[str, ...], - state_obs_keys: tuple[str, ...], -) -> dict[str, tuple[int, ...]]: - obs_shapes = _as_obs_shapes(raw_observation_size) - running_stats_shapes = _extract_running_stats_shapes(params) - sanitized: dict[str, tuple[int, ...]] = {} - for key, shape in obs_shapes.items(): - if key in pixel_obs_keys: - sanitized[key] = shape - continue - running_shape = running_stats_shapes.get(key) - if running_shape is not None and (shape == (1,) or shape != running_shape): - sanitized[key] = running_shape - else: - sanitized[key] = shape - - for key in state_obs_keys: - if not key: - continue - if key not in sanitized and key in running_stats_shapes: - sanitized[key] = running_stats_shapes[key] - - policy_params = _extract_policy_params(params) - cnn_names = _sorted_cnn_names(policy_params) - - if len(cnn_names) != len(pixel_obs_keys): - raise ValueError( - f"Checkpoint CNN count ({len(cnn_names)}) does not match pixel keys " - f"({len(pixel_obs_keys)}): {pixel_obs_keys}" - ) - - for i, key in enumerate(pixel_obs_keys): - shape = sanitized.get(key) - if shape is not None and len(shape) == 3: - continue - in_channels = int( - np.asarray(policy_params[cnn_names[i]]["Conv_0"]["kernel"]).shape[2] # type: ignore[index] - ) - sanitized[key] = (64, 64, in_channels) - return sanitized - - def convert_checkpoint_to_onnx( checkpoint_path: str, output_path: str | None = None, @@ -170,103 +154,58 @@ def convert_checkpoint_to_onnx( config = ppo_checkpoint.load_config(ckpt_dir) config_dict 
= config.to_dict() - observation_size = config_dict.get("observation_size") - if not isinstance(observation_size, dict): + observation_size_cfg = config_dict.get("observation_size") + if not isinstance(observation_size_cfg, dict): raise TypeError("Checkpoint observation_size must be a mapping for vision PPO.") - pixel_obs_keys = tuple(k for k in observation_size if k.startswith("pixels/")) - if not pixel_obs_keys: - raise ValueError( - "No pixel observation keys found in checkpoint observation_size." - ) + + policy_params = _extract_policy_params(params) + pixel_obs_keys, network_obs_shapes = _extract_pixel_obs_shapes( + observation_size_cfg, policy_params + ) network_factory_kwargs = config_dict.get("network_factory_kwargs", {}) if not isinstance(network_factory_kwargs, dict): network_factory_kwargs = {} network_factory_kwargs = dict(network_factory_kwargs) - activation_name = network_factory_kwargs.pop("activation_name", "") - activation = None - if activation_name: - activation = getattr(linen, activation_name, None) - if activation is None: - raise ValueError( - f"Unsupported activation_name in checkpoint: {activation_name!r}" - ) - state_obs_key = str(network_factory_kwargs.get("policy_obs_key", "")) - value_state_obs_key = str(network_factory_kwargs.get("value_obs_key", "")) - required_obs_keys = set(pixel_obs_keys) - if state_obs_key: - required_obs_keys.add(state_obs_key) - if value_state_obs_key: - required_obs_keys.add(value_state_obs_key) - filtered_observation_size = { - key: observation_size[key] - for key in required_obs_keys - if key in observation_size - } - missing_required_keys = sorted(required_obs_keys - set(filtered_observation_size)) - if missing_required_keys: + activation_name = str(network_factory_kwargs.pop("activation_name", "")) + activation = getattr(linen, activation_name, None) if activation_name else None + if activation_name and activation is None: raise ValueError( - "Checkpoint observation_size is missing required keys: " - + ", 
".join(missing_required_keys) + f"Unsupported activation_name in checkpoint: {activation_name!r}" ) + # Force pixel-only policy reconstruction. + network_factory_kwargs["policy_obs_key"] = "" + network_factory_kwargs["value_obs_key"] = "" + normalize = ( running_statistics.normalize if bool(config_dict.get("normalize_observations", False)) else (lambda x, y: x) ) - observation_size_for_network = _sanitize_observation_size_for_network( - filtered_observation_size, - params, - pixel_obs_keys, - (state_obs_key, value_state_obs_key), - ) + ppo_network_kwargs = dict(network_factory_kwargs) if activation is not None: ppo_network_kwargs["activation"] = activation + ppo_network = ppo_networks_vision.make_ppo_networks_vision( - observation_size=observation_size_for_network, + observation_size=network_obs_shapes, action_size=int(config_dict["action_size"]), preprocess_observations_fn=normalize, **ppo_network_kwargs, ) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) - obs_shapes = observation_size_for_network - model_input_shapes: dict[str, tuple[int, ...]] = {} - needs_pixel_shape_inference = False - for key in pixel_obs_keys: - shape = obs_shapes.get(key) - if shape is None or len(shape) != 3: - needs_pixel_shape_inference = True - continue - model_input_shapes[key] = shape - if state_obs_key: - state_shape = obs_shapes.get(state_obs_key) - if state_shape is not None: - model_input_shapes[state_obs_key] = state_shape - elif not needs_pixel_shape_inference: - raise ValueError( - f"state_obs_key='{state_obs_key}' missing from checkpoint observation_size" - ) - - observation_shapes_for_export: dict[str, tuple[int, ...]] | None - if needs_pixel_shape_inference: - print( - "Checkpoint pixel observation shapes are incomplete; " - "falling back to exporter shape inference." 
- ) - observation_shapes_for_export = None - else: - observation_shapes_for_export = model_input_shapes + export_obs_shapes = dict(network_obs_shapes) + export_obs_shapes["state"] = _extract_state_shape(observation_size_cfg, params) model_proto = franka_ppo_to_onnx.convert_policy_to_onnx( make_inference_fn=make_inference_fn, params=params, - observation_shapes=observation_shapes_for_export, + observation_shapes=export_obs_shapes, pixel_obs_keys=pixel_obs_keys, - state_obs_key=state_obs_key, + state_obs_key="", normalise_channels=True, ) @@ -289,28 +228,18 @@ def convert_checkpoint_to_onnx( sess_options=_make_ort_session_options(), providers=["CPUExecutionProvider"], ) + onnx_shapes = _onnx_input_shapes(session) jax_policy = make_inference_fn(params, deterministic=True) - - if observation_shapes_for_export is None: - model_input_shapes_for_test: dict[str, tuple[int, ...]] = {} - for inp in session.get_inputs(): - raw_shape = list(inp.shape)[1:] - if any((d is None or isinstance(d, str)) for d in raw_shape): - raise ValueError( - f"Cannot infer static shape for ONNX input {inp.name}: {inp.shape}. " - "Run without --num_tests or provide a checkpoint with full " - "observation shapes." 
- ) - model_input_shapes_for_test[inp.name] = tuple(int(d) for d in raw_shape) - else: - model_input_shapes_for_test = model_input_shapes + jax_input_keys = set(network_obs_shapes) rng = np.random.default_rng(0) max_err = 0.0 for i in range(num_tests): - obs = _build_random_obs(rng, model_input_shapes_for_test) + obs = _build_random_obs(rng, onnx_shapes) onnx_action = session.run(["continuous_actions"], obs)[0][0] - jax_obs = {k: jax.numpy.asarray(v) for k, v in obs.items()} + jax_obs = { + k: jax.numpy.asarray(v) for k, v in obs.items() if k in jax_input_keys + } jax_action = np.asarray(jax_policy(jax_obs, jax.random.PRNGKey(i))[0][0]) err = float(np.max(np.abs(onnx_action - jax_action))) max_err = max(max_err, err) @@ -321,6 +250,7 @@ def convert_checkpoint_to_onnx( raise AssertionError( f"ONNX/JAX mismatch: max_abs_err={max_err:.6e} > atol={atol}" ) + return output_file diff --git a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py index 3c8a6762a..3182d9486 100644 --- a/ss2r/algorithms/ppo/franka_ppo_to_onnx.py +++ b/ss2r/algorithms/ppo/franka_ppo_to_onnx.py @@ -126,7 +126,7 @@ def __init__( self.action_size = action_size self.pixel_obs_keys = tuple(pixel_obs_keys) self.normalise_channels = normalise_channels - self.state_obs_key = state_obs_key + del state_obs_key # state is intentionally not used by this policy. 
self.cnn_blocks = [] for i, _ in enumerate(self.pixel_obs_keys): @@ -193,12 +193,10 @@ def call(self, obs: Mapping[str, tf.Tensor]) -> tf.Tensor: hidden = tf.reduce_mean(hidden, axis=(1, 2)) cnn_outs.append(hidden) - if self.state_obs_key: - cnn_outs.append(obs[self.state_obs_key]) - logits = self.mlp_block(tf.concat(cnn_outs, axis=-1)) loc, _ = tf.split(logits, 2, axis=-1) - return tf.tanh(loc) + action = tf.tanh(loc) + return action else: @@ -264,24 +262,21 @@ def convert_policy_to_onnx( ) if observation_shapes is None: - h, w = _resolve_render_hw(cfg) observation_shapes = {} - for i, key in enumerate(pixel_obs_keys): - channels = int( - np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore - ) - observation_shapes[key] = (h, w, channels) - - used_observation_shapes = {k: tuple(observation_shapes[k]) for k in pixel_obs_keys} - if state_obs_key: - if state_obs_key not in observation_shapes: - raise ValueError( - f"state_obs_key='{state_obs_key}' not found in observation_shapes keys=" - f"{list(observation_shapes.keys())}" - ) - used_observation_shapes[state_obs_key] = tuple( - observation_shapes[state_obs_key] + else: + observation_shapes = dict(observation_shapes) + + h, w = _resolve_render_hw(cfg) + used_observation_shapes: dict[str, tuple[int, ...]] = {} + for i, key in enumerate(pixel_obs_keys): + if key in observation_shapes: + used_observation_shapes[key] = tuple(observation_shapes[key]) + continue + channels = int( + np.asarray(policy_params[f"CNN_{i}"]["Conv_0"]["kernel"]).shape[2] # type: ignore ) + used_observation_shapes[key] = (h, w, channels) + del state_obs_key # state is intentionally never passed to the policy network. 
tf_policy_network = VisionPPOPolicy( action_size=action_size, @@ -289,18 +284,24 @@ def convert_policy_to_onnx( hidden_layer_sizes=hidden_layers, activation=activation, normalise_channels=normalise_channels, - state_obs_key=state_obs_key, ) + model_input_shapes = dict(used_observation_shapes) + if "state" in observation_shapes: + model_input_shapes["state"] = tuple(observation_shapes["state"]) dummy_obs = { k: np.ones((1, *shape), dtype=np.float32) - for k, shape in used_observation_shapes.items() + for k, shape in model_input_shapes.items() } tf_policy_network(dummy_obs).numpy() transfer_weights(policy_params, tf_policy_network) inference_fn = make_inference_fn(params, deterministic=True) - jax_obs = {k: jax.numpy.asarray(v) for k, v in dummy_obs.items()} + jax_obs = { + k: jax.numpy.asarray(v) + for k, v in dummy_obs.items() + if k in used_observation_shapes + } jax_pred = np.asarray(inference_fn(jax_obs, jax.random.PRNGKey(0))[0][0]) tf_pred = np.asarray(tf_policy_network(dummy_obs).numpy()[0]) max_abs_err = float(np.max(np.abs(jax_pred - tf_pred))) @@ -312,7 +313,7 @@ def convert_policy_to_onnx( input_signature=[ { k: tf.TensorSpec([1, *shape], tf.float32, name=k) - for k, shape in used_observation_shapes.items() + for k, shape in model_input_shapes.items() } ], opset=11, From 805e197ca94213db0016df61ab5b0caa9ed4b894 Mon Sep 17 00:00:00 2001 From: Yarden As Date: Sun, 22 Feb 2026 12:21:40 +0100 Subject: [PATCH 18/18] Remove annotations --- franka_experiments/ppo_vision_flax_to_onnx.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/franka_experiments/ppo_vision_flax_to_onnx.py b/franka_experiments/ppo_vision_flax_to_onnx.py index 0b23f3303..368c1c1a4 100644 --- a/franka_experiments/ppo_vision_flax_to_onnx.py +++ b/franka_experiments/ppo_vision_flax_to_onnx.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import argparse from pathlib import Path from typing import Mapping