diff --git a/README.md b/README.md index 91269cb5..33c9821f 100644 --- a/README.md +++ b/README.md @@ -120,8 +120,69 @@ for _ in range(100): print(obs) if truncated or terminated: logger.info("Truncated or terminated!") - return + exit() ``` + +### Remote Procedure Call (RPC) Client and Server +#### Server +```python +from rcs.envs.creators import SimEnvCreator +from rcs.envs.utils import ( + default_mujoco_cameraset_cfg, + default_sim_gripper_cfg, + default_sim_robot_cfg, +) +from rcs.envs.base import ControlMode, RelativeTo +from rcs.rpc.server import RcsServer + +def run_server(): + env = SimEnvCreator()( + control_mode=ControlMode.JOINTS, + collision_guard=False, + robot_cfg=default_sim_robot_cfg(), + gripper_cfg=default_sim_gripper_cfg(), + cameras=default_mujoco_cameraset_cfg(), + max_relative_movement=0.1, + relative_to=RelativeTo.LAST_STEP, + ) + server = RcsServer(env, port=50051) + server.start() + +if __name__ == "__main__": + run_server() +``` + +#### Client +```python +import time +from python.rcs.rpc.client import RcsClient + +if __name__ == "__main__": + # Create the client (adjust host/port if needed) + client = RcsClient(host="localhost", port=50051) + + try: + print("Resetting environment...") + obs = client.reset() + print(f"Initial observation: {obs}") + + for i in range(5): + print(f"\nStep {i+1}") + # Replace with a valid action for your environment + action = 0 + obs, reward, terminated, truncated, info = client.step(action) + print(f"Obs: {obs}, Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}, Info: {info}") + if terminated or truncated: + print("Episode finished, resetting...") + obs = client.reset() + print(f"Reset observation: {obs}") + time.sleep(0.5) + finally: + print("Closing client.") + client.close() +``` + + ### Examples Checkout the python examples in the [examples](examples) folder: - [fr3_direct_control.py](examples/fr3.py) shows direct robot control with RCS's python bindings diff --git a/examples/rpc_run_client.py b/examples/rpc_run_client.py new file mode 100644 index 00000000..7cf4ce8f --- /dev/null +++ b/examples/rpc_run_client.py @@ -0,0 +1,27 @@ +import time + +from python.rcs.rpc.client import RcsClient + +if __name__ == "__main__": + # Create the client (adjust host/port if needed) + client = RcsClient(host="localhost", port=50051) + + try: + print("Resetting environment...") + obs = client.reset() + print(f"Initial observation: {obs}") + + for i in range(5): + print(f"\nStep {i+1}") + # Replace with a valid action for your environment + action = 0 + obs, reward, terminated, truncated, info = client.step(action) + print(f"Obs: {obs}, Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}, Info: {info}") + if terminated or truncated: + print("Episode finished, resetting...") + obs = client.reset() + print(f"Reset observation: {obs}") + time.sleep(0.5) + finally: + print("Closing client.") + client.close() diff --git a/examples/rpc_run_server.py b/examples/rpc_run_server.py new file mode 100644 index 00000000..74a81125 --- /dev/null +++ b/examples/rpc_run_server.py @@ -0,0 +1,26 @@ +from rcs.envs.base import ControlMode, RelativeTo +from rcs.envs.creators import SimEnvCreator +from rcs.envs.utils import ( + default_mujoco_cameraset_cfg, + default_sim_gripper_cfg, + default_sim_robot_cfg, +) +from rcs.rpc.server import RcsServer + + +def run_server(): + env = SimEnvCreator()( + control_mode=ControlMode.JOINTS, + collision_guard=False, + robot_cfg=default_sim_robot_cfg(), + gripper_cfg=default_sim_gripper_cfg(), + cameras=default_mujoco_cameraset_cfg(), + max_relative_movement=0.1, + relative_to=RelativeTo.LAST_STEP, + ) + server = RcsServer(env, port=50051) + server.start() + + +if __name__ == "__main__": + run_server() diff --git a/pyproject.toml b/pyproject.toml index 2c689d6b..aceb5db2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = ["websockets>=11.0", "tilburg-hand", "digit-interface", "ompl>=1.7.0", + "rpyc==6.0.2", ] readme = "README.md" maintainers = [ diff --git a/python/rcs/rpc/client.py b/python/rcs/rpc/client.py new file mode 100644 index 00000000..c52f2ad2 --- /dev/null +++ b/python/rcs/rpc/client.py @@ -0,0 +1,33 @@ +import gymnasium as gym +import rpyc +from rpyc.utils.classic import obtain + + +class RcsClient(gym.Env): + def __init__(self, host="localhost", port=50051): + super().__init__() + self.conn = rpyc.connect(host, port) + self.server = self.conn.root + # Optionally, fetch spaces from server if needed + # self.observation_space = ... + # self.action_space = ... + + def step(self, action): + return self.server.step(action) + + def reset(self, **kwargs): + return self.server.reset(**kwargs) + + def get_obs(self): + return self.server.get_obs() + + @property + def unwrapped(self): + return self.server.unwrapped() + + @property + def action_space(self): + return obtain(self.server.action_space()) + + def close(self): + self.conn.close() diff --git a/python/rcs/rpc/server.py b/python/rcs/rpc/server.py new file mode 100644 index 00000000..ed79cb9b --- /dev/null +++ b/python/rcs/rpc/server.py @@ -0,0 +1,49 @@ +# import wrapper +import rpyc +from gymnasium import Wrapper +from rpyc.utils.server import ThreadedServer + +rpyc.core.protocol.DEFAULT_CONFIG["allow_pickle"] = True + + +@rpyc.service +class RcsServer(Wrapper, rpyc.Service): + def __init__(self, env, host="localhost", port=50051): + super().__init__(env) + self.host = host + self.port = port + + @rpyc.exposed + def step(self, action): + """Perform a step in the environment using the Wrapper base class.""" + return super().step(action) + + @rpyc.exposed + def reset(self, **kwargs): + """Reset the environment using the Wrapper base class.""" + return super().reset(**kwargs) + + @rpyc.exposed + def get_obs(self): + """Get the current observation using the Wrapper base class if available.""" + if hasattr(super(), "get_obs"): + return super().get_obs() + if hasattr(self.env, "get_obs"): + return self.env.get_obs() + error = "The environment does not have a get_obs method." + raise NotImplementedError(error) + + @rpyc.exposed + def unwrapped(self): + """Return the unwrapped environment using the Wrapper base class.""" + return super().unwrapped + + @rpyc.exposed + def action_space(self): + """Return the action space using the Wrapper base class.""" + return super().action_space + + def start(self): + print(f"Starting RcsServer RPC (looped OneShotServer) on {self.host}:{self.port}") + t = ThreadedServer(self, port=self.port) + t.start() diff --git a/python/tests/test_rpc.py b/python/tests/test_rpc.py new file mode 100644 index 00000000..10a2efab --- /dev/null +++ b/python/tests/test_rpc.py @@ -0,0 +1,209 @@ +import multiprocessing +import os +import socket +import sys +import time +import traceback +from contextlib import suppress +from multiprocessing.context import ForkServerContext, SpawnContext +from typing import Optional, Type, Union # Add Type and Union here + +import pytest +from rcs.envs.base import ControlMode, RelativeTo +from rcs.envs.creators import SimEnvCreator +from rcs.envs.utils import default_sim_gripper_cfg, default_sim_robot_cfg +from rcs.rpc.client import RcsClient +from rcs.rpc.server import RcsServer + +HOST = "127.0.0.1" + + +def get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((HOST, 0)) + return s.getsockname()[1] + + +def wait_for_port( + host: str, + port: int, + timeout: float, + server_proc: Optional[multiprocessing.Process] = None, + err_q: Optional[multiprocessing.Queue] = None, +) -> None: + start = time.time() + last_exc = None + while time.time() - start < timeout: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(0.5) + try: + if s.connect_ex((host, port)) == 0: + return + except OSError as e: + last_exc = e + # If the server process died, surface its error immediately + if server_proc is not None and not server_proc.is_alive(): + server_err = None + if err_q is not None: + with suppress(Exception): + server_err = err_q.get_nowait() + msg = f"Server process exited early (exitcode={server_proc.exitcode})." + if server_err: + msg += f"\nServer traceback:\n{server_err}" + raise RuntimeError(msg) + time.sleep(0.2) + server_err = None + if err_q is not None: + with suppress(Exception): + server_err = err_q.get_nowait() + msg = f"Timed out waiting for {host}:{port} to open." + if last_exc: + msg += f" Last socket error: {last_exc}" + if server_proc is not None and not server_proc.is_alive(): + msg += f" Server exitcode={server_proc.exitcode}." + if server_err: + msg += f"\nServer traceback:\n{server_err}" + raise TimeoutError(msg) + + +def run_server(host: str, port: int, err_q: multiprocessing.Queue) -> None: + try: + env = SimEnvCreator()( + control_mode=ControlMode.JOINTS, + collision_guard=False, + robot_cfg=default_sim_robot_cfg(), + gripper_cfg=default_sim_gripper_cfg(), + # Disabled to avoid rendering problem in python subprocess. + # cameras=default_mujoco_cameraset_cfg(), + max_relative_movement=0.1, + relative_to=RelativeTo.LAST_STEP, + ) + # Bind explicitly to IPv4 loopback + server = RcsServer(env, host=host, port=port) + try: + server.start() + finally: + # If start returns (non-blocking implementation), keep process alive + while True: + time.sleep(1) + except Exception: + tb = "".join(traceback.format_exception(*sys.exc_info())) + with suppress(Exception): + err_q.put(tb) + sys.exit(1) + + +def _mp_context() -> Union[SpawnContext, ForkServerContext]: + # Prefer spawn to avoid fork-related issues with GL/MuJoCo/threaded libs + methods = multiprocessing.get_all_start_methods() + if "spawn" in methods: + return multiprocessing.get_context("spawn") + if "forkserver" in methods: + return multiprocessing.get_context("forkserver") + + msg = "No suitable multiprocessing context found." + raise RuntimeError(msg) + + +def _external_server_from_env() -> tuple[str, int] | None: + # Set RCS_TEST_HOST and RCS_TEST_PORT to reuse an already running server. + host = os.getenv("RCS_TEST_HOST") + port = os.getenv("RCS_TEST_PORT") + if host and port: + try: + return host, int(port) + except ValueError: + pass + # Convenience: RCS_TEST_REUSE_SERVER=1 will use HOST + default port 50055 + if os.getenv("RCS_TEST_REUSE_SERVER") == "1": + return HOST, 50055 + return None + + +def test_run_server_starts_and_stops(): + # Skip if reusing an external server + ext = _external_server_from_env() + if ext: + pytest.skip("External server reuse enabled via env; skipping spawn test.") + ctx = _mp_context() + err_q = ctx.Queue() + port = get_free_port() + server_proc = ctx.Process(target=run_server, args=(HOST, port, err_q)) + server_proc.start() + try: + wait_for_port(HOST, port, timeout=120.0, server_proc=server_proc, err_q=err_q) # type: ignore + assert server_proc.is_alive(), "Server process did not start as expected." + finally: + if server_proc.is_alive(): + server_proc.terminate() + server_proc.join(timeout=5) + assert not server_proc.is_alive(), "Server process did not terminate as expected." + + +class TestRcsClientServer: + client: RcsClient + host: str = HOST + port: int = 0 + server_proc = None + err_q: Optional[multiprocessing.Queue] = None + + @classmethod + def setup_class(cls: Type["TestRcsClientServer"]): + ext = _external_server_from_env() + if ext: + cls.host, cls.port = ext + cls.server_proc = None + cls.err_q = None + wait_for_port(cls.host, cls.port, timeout=60.0) + cls.client = RcsClient(host=cls.host, port=cls.port) + return + + ctx = _mp_context() + cls.err_q = ctx.Queue() + cls.host, cls.port = HOST, get_free_port() + cls.server_proc = ctx.Process(target=run_server, args=(cls.host, cls.port, cls.err_q)) + cls.server_proc.start() + # Wait until the server is actually listening or fail early if it crashed + wait_for_port(cls.host, cls.port, timeout=180.0, server_proc=cls.server_proc, err_q=cls.err_q) # type: ignore + cls.client = RcsClient(host=cls.host, port=cls.port) + + @classmethod + def teardown_class(cls: Type["TestRcsClientServer"]): + try: + if getattr(cls, "client", None): + cls.client.close() + finally: + if getattr(cls, "server_proc", None) and cls.server_proc and cls.server_proc.is_alive(): + cls.server_proc.terminate() + cls.server_proc.join(timeout=5) + + def test_reset(self): + obs, info = self.client.reset() + assert obs is not None, "reset did not return an observation" + + def test_step(self): + self.client.reset() + act = self.client.action_space.sample() + step_result = self.client.step(act) + assert isinstance(step_result, (tuple, list)), "step did not return a tuple or list" + + def test_get_obs(self): + self.client.reset() + obs2 = self.client.get_obs() + assert obs2 is not None, "get_obs did not return an observation" + + def test_unwrapped(self): + _ = self.client.unwrapped + + def test_close(self): + if self.client is not None: + self.client.close() + # Reconnect for further tests + wait_for_port( + self.__class__.host, + self.__class__.port, + timeout=15.0, + server_proc=self.__class__.server_proc, # type: ignore + err_q=self.__class__.err_q, + ) + self.__class__.client = RcsClient(host=self.__class__.host, port=self.__class__.port)