diff --git a/.github/workflows/AutoApprove.yml b/.github/workflows/AutoApprove.yml new file mode 100644 index 0000000..f06a315 --- /dev/null +++ b/.github/workflows/AutoApprove.yml @@ -0,0 +1,23 @@ +# will be removed when this project has more than one maintainers + +name: AutoApprove + +on: + pull_request: + types: [opened, reopened, synchronize, ready_for_review] + +jobs: + approve: + if: | + github.event.pull_request.user.login == 'kanarus' && + !github.event.pull_request.draft + runs-on: ubuntu-latest + permissions: + pull-requests: write + steps: + - uses: actions/checkout@v4 + - name: approve + env: + GH_TOKEN: ${{ github.token }} + run: | + gh pr review ${{ github.event.number }} --approve diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 37f4a04..c7dd7c1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -6,7 +6,7 @@ on: branches: [main, v*] jobs: - CI: + build: runs-on: ubuntu-latest strategy: @@ -16,6 +16,18 @@ jobs: steps: - uses: actions/checkout@v4 - - run: rustup update && rustup default ${{ matrix.toolchain }} - - - run: cargo test + - run: | + rustup update + rustup default ${{ matrix.toolchain }} + rustup component add rustfmt ### required for rusty_mujoco to build ### + + - name: install mujoco and set MUJOCO_DIR + run: | + mkdir -p $HOME/.mujoco + cd $HOME/.mujoco + wget https://github.com/google-deepmind/mujoco/releases/download/3.3.2/mujoco-3.3.2-linux-x86_64.tar.gz + tar -xzf mujoco-3.3.2-linux-x86_64.tar.gz + echo "MUJOCO_DIR=$HOME/.mujoco/mujoco-3.3.2" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=$HOME/.mujoco/mujoco-3.3.2/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV + + - run: cargo build diff --git a/.gitignore b/.gitignore index 96ef6c0..972b0c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -/target -Cargo.lock +**/target +**/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 21071ee..8a979e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,9 @@ homepage = "https://crates.io/crates/oxide_control" repository = "https://github.com/rust-control/oxide_control" readme = "README.md" license = "MIT" -description = "" +description = "Rust software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo" keywords = ["mujoco", "rl", "ml", "physics", "robotics"] categories = ["science::robotics", "simulation"] [dependencies] -rusty_mujoco = { path = "../rusty_mujoco" } +rusty_mujoco = "0.1.0" diff --git a/README.md b/README.md new file mode 100644 index 0000000..94ab186 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +
+

oxide_control: The dm_control layer for Rust

+
+ +`oxide_control` is a Rust software stack for +physics-based simulation and Reinforcement Learning environments, using MuJoCo. + +This is built up on [rusty_mujoco](https://github.com/rust-control/rusty_mujoco) binding, +and provides a high-level interface similar to [dm_control](https://github.com/google-deepmind/dm_control) in Python. + +## Features + + diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..b28d8cf --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 160 diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..20df258 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,82 @@ +use rusty_mujoco::{obj, ObjectId}; + +pub enum Error { + Mujoco(::rusty_mujoco::MjError), + Mjs(String), + NameNotFound(&'static str), + PhysicsDiverged, + JointTypeNotMatch { + expected: ::rusty_mujoco::bindgen::mjtJoint, + found: ::rusty_mujoco::bindgen::mjtJoint, + }, + ActuatorStateless(ObjectId), + PluginStateless(ObjectId), + BodyNotMocap(ObjectId), +} + +impl From<::rusty_mujoco::MjError> for Error { + fn from(e: ::rusty_mujoco::MjError) -> Self { + Error::Mujoco(e) + } +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Mujoco(e) => write!(f, "Error::MuJoCo({e:?})"), + Error::Mjs(msg) => write!(f, "Error::Mjs({msg})"), + Error::NameNotFound(name) => write!(f, "Error::NameNotFound({name})"), + Error::PhysicsDiverged => write!(f, "Error::PhysicsDiverged"), + Error::JointTypeNotMatch { expected, found } => { + write!(f, "Error::JointTypeNotMatch(expected: {expected:?}, found: {found:?})") + } + Error::ActuatorStateless(actuator_id) => { + write!(f, "Error::ActuatorStateless({actuator_id:?})") + } + Error::PluginStateless(plugin_id) => { + write!(f, "Error::PluginStateless({plugin_id:?})") + } + Error::BodyNotMocap(body_id) => { + write!(f, "Error::BodyNotMocap({body_id:?})") + } + } + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Mujoco(e) => write!(f, "MuJoCo error: {e}"), + Error::Mjs(msg) => write!(f, "MuJoCo error: {msg}"), + Error::NameNotFound(name) => write!(f, "Given name not found: `{name}`"), + Error::PhysicsDiverged => write!(f, "Physics simulation diverged"), + Error::JointTypeNotMatch { expected, found } => { + write!(f, "Joint type mismatch: expected {expected:?}, found {found:?}") + } + Error::ActuatorStateless(actuator_id) => { + write!(f, "Actuator with ID {actuator_id:?} is stateless unexpectedly") + } + Error::PluginStateless(plugin_id) => { + write!(f, "Plugin with ID {plugin_id:?} is stateless unexpectedly") + } + Error::BodyNotMocap(body_id) => { + write!(f, "Body with ID {body_id:?} is not a mocap body") + } + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Mujoco(e) => Some(e), + Error::Mjs(_) => None, + Error::NameNotFound(_) => None, + Error::PhysicsDiverged => None, + Error::JointTypeNotMatch { .. } => None, + Error::ActuatorStateless(_) => None, + Error::PluginStateless(_) => None, + Error::BodyNotMocap(_) => None, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index b93cf3f..fa6cef6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,88 @@ -pub fn add(left: u64, right: u64) -> u64 { - left + right +pub mod error; +pub mod physics; + +pub use physics::Physics as RawPhysics; + +pub trait Physics: std::ops::DerefMut {} + +pub trait Task { + type Physics: Physics; + type Observation: Observation; + type Action: Action; + fn discount(&self) -> f64; + fn init_episode(&self, physics: &mut Self::Physics); + fn should_finish_episode(&self, observation: &Self::Observation) -> bool; + fn get_reward(&self, observation: &Self::Observation, action: &Self::Action) -> f64; +} + +pub trait Observation { + type Physics: Physics; + fn generate(physics: &Self::Physics) -> Self; +} + +pub trait Action { + type Physics: Physics; + fn apply(&self, actuators: &mut physics::Actuators<'_>); +} + +pub struct Environment { + task: T, + physics: T::Physics, +} + +impl Environment { + pub fn new(physics: T::Physics, task: T) -> Self { + Self { task, physics } + } + + pub fn task(&self) -> &T { + &self.task + } + + pub fn physics(&self) -> &T::Physics { + &self.physics + } + pub fn physics_mut(&mut self) -> &mut T::Physics { + &mut self.physics + } } -#[cfg(test)] -mod tests { - use super::*; +pub enum TimeStep { + Step { + observation: O, + reward: f64, + discount: f64, + }, + Finish { + observation: O, + reward: f64, + }, +} + +impl Environment { + pub fn reset(&mut self) -> T::Observation { + self.task.init_episode(&mut self.physics); + T::Observation::generate(&self.physics) + } + + pub fn step(&mut self, action: T::Action) -> TimeStep { + action.apply(&mut self.physics.actuators()); + self.physics.step(); + + let observation = T::Observation::generate(&self.physics); + let reward = self.task.get_reward(&observation, &action); - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + if self.task.should_finish_episode(&observation) { + TimeStep::Finish { + observation, + reward, + } + } else { + TimeStep::Step { + observation, + reward, + discount: self.task.discount(), + } + } } } diff --git a/src/physics.rs b/src/physics.rs new file mode 100644 index 0000000..baec50f --- /dev/null +++ b/src/physics.rs @@ -0,0 +1,171 @@ +pub use rusty_mujoco as binding; +pub use binding::{mjModel, mjData, ObjectId, obj, Joint, joint, mjMAXVAL, mjMINVAL}; + +use crate::error::Error; + +pub struct Physics { + model: mjModel, + data: mjData, +} + +impl Physics { + pub fn from_xml(xml_path: impl AsRef) -> Result { + let model = binding::mj_loadXML(xml_path.as_ref().to_str().unwrap())?; + let data = binding::mj_makeData(&model); + Ok(Self { model, data }) + } + + pub fn from_xml_string(xml_string: impl Into) -> Result { + let mut spec = binding::mj_parseXMLString(xml_string.into())?; + let model = binding::mj_compile(&mut spec) + .ok_or_else(|| Error::Mjs(binding::mjs_getError(&mut spec).unwrap_or_else(String::new)))?; + let data = binding::mj_makeData(&model); + Ok(Self { model, data }) + } + + pub fn model(&self) -> &mjModel { + &self.model + } + + pub fn data(&self) -> &mjData { + &self.data + } + pub fn data_mut(&mut self) -> &mut mjData { + &mut self.data + } + + pub fn model_data(&self) -> (&mjModel, &mjData) { + (&self.model, &self.data) + } + pub fn model_datamut(&mut self) -> (&mjModel, &mut mjData) { + (&self.model, &mut self.data) + } + + pub fn step(&mut self) { + rusty_mujoco::mj_step(&self.model, &mut self.data); + } + + pub fn forward(&mut self) { + rusty_mujoco::mj_forward(&self.model, &mut self.data); + } + + pub fn reset(&mut self) { + rusty_mujoco::mj_resetData(&self.model, &mut self.data); + } + + pub fn object_id(&self, name: &str) -> Option> { + self.model.object_id(name) + } + + pub fn object_name(&self, id: ObjectId) -> String { + binding::mj_id2name::(&self.model, id) + } +} + +pub struct Actuators<'a> { + physics: &'a mut Physics, +} +impl<'a> Actuators<'a> { + pub fn set(&mut self, id: ObjectId, control: f64) { + self.physics.set_ctrl(id, control); + } +} +impl Physics { + pub fn actuators(&mut self) -> Actuators<'_> { + Actuators { + physics: self, + } + } +} + +impl Physics { + pub fn time(&self) -> f64 { + self.data.time() + } + pub fn set_time(&mut self, time: f64) { + self.data.set_time(time); + } + + pub fn ctrl(&self, id: ObjectId) -> f64 { + self.data.ctrl(id) + } + pub fn set_ctrl(&mut self, id: ObjectId, value: f64) { + self.data.set_ctrl(id, value); + } + + pub fn act(&self, id: ObjectId) -> Option { + self.data.act(id, &self.model) + } + /// Set the actuator activation value. `None` when the actuator is stateless. + pub fn set_act(&mut self, id: ObjectId, value: f64) -> Option<()> { + self.data.set_act(id, value, &self.model) + } + + pub fn qpos(&self, id: ObjectId) -> J::Qpos { + self.data.qpos(id, &self.model) + } + pub fn set_qpos(&mut self, id: ObjectId, qpos: J::Qpos) { + self.data.set_qpos(id, qpos, &self.model); + } + + pub fn qvel(&self, id: ObjectId) -> J::Qvel { + self.data.qvel(id, &self.model) + } + pub fn set_qvel(&mut self, id: ObjectId, qvel: J::Qvel) { + self.data.set_qvel(id, qvel, &self.model); + } + + pub fn qacc_warmstart(&self, id: ObjectId) -> f64 { + self.data.qacc_warmstart(id) + } + pub fn set_qacc_warmstart(&mut self, id: ObjectId, value: f64) { + self.data.set_qacc_warmstart(id, value); + } + + pub fn plugin_state(&self, id: ObjectId) -> Option { + self.data.plugin_state(id, &self.model) + } + /// Set the plugin state. Returns `None` if the plugin does not have a state. + pub fn set_plugin_state(&mut self, id: ObjectId, value: f64) -> Option<()> { + self.data.set_plugin_state(id, value, &self.model) + } + + pub fn qfrc_applied(&self, id: ObjectId) -> f64 { + self.data.qfrc_applied(id) + } + pub fn set_qfrc_applied(&mut self, id: ObjectId, value: f64) { + self.data.set_qfrc_applied(id, value); + } + + pub fn xfrc_applied(&self, id: ObjectId) -> [f64; 6] { + self.data.xfrc_applied(id) + } + pub fn set_xfrc_applied(&mut self, id: ObjectId, value: [f64; 6]) { + self.data.set_xfrc_applied(id, value); + } + + pub fn eq_active(&self, id: ObjectId) -> bool { + self.data.eq_active(id) + } + pub fn set_eq_active(&mut self, id: ObjectId, value: bool) { + self.data.set_eq_active(id, value); + } + + /// `None` when the body is not a mocap body. + pub fn mocap_pos(&self, id: ObjectId) -> Option<[f64; 3]> { + self.data.mocap_pos(id, &self.model) + } + /// Set the mocap position. Returns `None` if the body is not a mocap body. + pub fn set_mocap_pos(&mut self, id: ObjectId, pos: [f64; 3]) -> Option<()> { + self.data.set_mocap_pos(id, pos, &self.model) + } + + /// `None` when the body is not a mocap body. + pub fn mocap_quat(&self, id: ObjectId) -> Option<[f64; 4]> { + self.data.mocap_quat(id, &self.model) + } + /// Set the mocap quaternion. Returns `None` if the body is not a mocap body. + pub fn set_mocap_quat(&mut self, id: ObjectId, quat: [f64; 4]) -> Option<()> { + self.data.set_mocap_quat(id, quat, &self.model) + } +}