diff --git a/.github/workflows/AutoApprove.yml b/.github/workflows/AutoApprove.yml
new file mode 100644
index 0000000..f06a315
--- /dev/null
+++ b/.github/workflows/AutoApprove.yml
@@ -0,0 +1,23 @@
+# will be removed when this project has more than one maintainers
+
+name: AutoApprove
+
+on:
+ pull_request:
+ types: [opened, reopened, synchronize, ready_for_review]
+
+jobs:
+ approve:
+ if: |
+ github.event.pull_request.user.login == 'kanarus' &&
+ !github.event.pull_request.draft
+ runs-on: ubuntu-latest
+ permissions:
+ pull-requests: write
+ steps:
+ - uses: actions/checkout@v4
+ - name: approve
+ env:
+ GH_TOKEN: ${{ github.token }}
+ run: |
+ gh pr review ${{ github.event.number }} --approve
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 37f4a04..c7dd7c1 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -6,7 +6,7 @@ on:
branches: [main, v*]
jobs:
- CI:
+ build:
runs-on: ubuntu-latest
strategy:
@@ -16,6 +16,18 @@ jobs:
steps:
- uses: actions/checkout@v4
- - run: rustup update && rustup default ${{ matrix.toolchain }}
-
- - run: cargo test
+ - run: |
+ rustup update
+ rustup default ${{ matrix.toolchain }}
+ rustup component add rustfmt ### required for rusty_mujoco to build ###
+
+ - name: install mujoco and set MUJOCO_DIR
+ run: |
+ mkdir -p $HOME/.mujoco
+ cd $HOME/.mujoco
+ wget https://github.com/google-deepmind/mujoco/releases/download/3.3.2/mujoco-3.3.2-linux-x86_64.tar.gz
+ tar -xzf mujoco-3.3.2-linux-x86_64.tar.gz
+ echo "MUJOCO_DIR=$HOME/.mujoco/mujoco-3.3.2" >> $GITHUB_ENV
+ echo "LD_LIBRARY_PATH=$HOME/.mujoco/mujoco-3.3.2/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV
+
+ - run: cargo build
diff --git a/.gitignore b/.gitignore
index 96ef6c0..972b0c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
-/target
-Cargo.lock
+**/target
+**/Cargo.lock
diff --git a/Cargo.toml b/Cargo.toml
index 21071ee..8a979e8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,9 +8,9 @@ homepage = "https://crates.io/crates/oxide_control"
repository = "https://github.com/rust-control/oxide_control"
readme = "README.md"
license = "MIT"
-description = ""
+description = "Rust software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo"
keywords = ["mujoco", "rl", "ml", "physics", "robotics"]
categories = ["science::robotics", "simulation"]
[dependencies]
-rusty_mujoco = { path = "../rusty_mujoco" }
+rusty_mujoco = "0.1.0"
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..94ab186
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+
+
+`oxide_control` is a Rust software stack for
+physics-based simulation and Reinforcement Learning environments, using MuJoCo.
+
+This is built up on [rusty_mujoco](https://github.com/rust-control/rusty_mujoco) binding,
+and provides a high-level interface similar to [dm_control](https://github.com/google-deepmind/dm_control) in Python.
+
+## Features
+
+
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000..b28d8cf
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1 @@
+max_width = 160
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..20df258
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,82 @@
+use rusty_mujoco::{obj, ObjectId};
+
+pub enum Error {
+ Mujoco(::rusty_mujoco::MjError),
+ Mjs(String),
+ NameNotFound(&'static str),
+ PhysicsDiverged,
+ JointTypeNotMatch {
+ expected: ::rusty_mujoco::bindgen::mjtJoint,
+ found: ::rusty_mujoco::bindgen::mjtJoint,
+ },
+ ActuatorStateless(ObjectId),
+ PluginStateless(ObjectId),
+ BodyNotMocap(ObjectId),
+}
+
+impl From<::rusty_mujoco::MjError> for Error {
+ fn from(e: ::rusty_mujoco::MjError) -> Self {
+ Error::Mujoco(e)
+ }
+}
+
+impl std::fmt::Debug for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Error::Mujoco(e) => write!(f, "Error::MuJoCo({e:?})"),
+ Error::Mjs(msg) => write!(f, "Error::Mjs({msg})"),
+ Error::NameNotFound(name) => write!(f, "Error::NameNotFound({name})"),
+ Error::PhysicsDiverged => write!(f, "Error::PhysicsDiverged"),
+ Error::JointTypeNotMatch { expected, found } => {
+ write!(f, "Error::JointTypeNotMatch(expected: {expected:?}, found: {found:?})")
+ }
+ Error::ActuatorStateless(actuator_id) => {
+ write!(f, "Error::ActuatorStateless({actuator_id:?})")
+ }
+ Error::PluginStateless(plugin_id) => {
+ write!(f, "Error::PluginStateless({plugin_id:?})")
+ }
+ Error::BodyNotMocap(body_id) => {
+ write!(f, "Error::BodyNotMocap({body_id:?})")
+ }
+ }
+ }
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Error::Mujoco(e) => write!(f, "MuJoCo error: {e}"),
+ Error::Mjs(msg) => write!(f, "MuJoCo error: {msg}"),
+ Error::NameNotFound(name) => write!(f, "Given name not found: `{name}`"),
+ Error::PhysicsDiverged => write!(f, "Physics simulation diverged"),
+ Error::JointTypeNotMatch { expected, found } => {
+ write!(f, "Joint type mismatch: expected {expected:?}, found {found:?}")
+ }
+ Error::ActuatorStateless(actuator_id) => {
+ write!(f, "Actuator with ID {actuator_id:?} is stateless unexpectedly")
+ }
+ Error::PluginStateless(plugin_id) => {
+ write!(f, "Plugin with ID {plugin_id:?} is stateless unexpectedly")
+ }
+ Error::BodyNotMocap(body_id) => {
+ write!(f, "Body with ID {body_id:?} is not a mocap body")
+ }
+ }
+ }
+}
+
+impl std::error::Error for Error {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ match self {
+ Error::Mujoco(e) => Some(e),
+ Error::Mjs(_) => None,
+ Error::NameNotFound(_) => None,
+ Error::PhysicsDiverged => None,
+ Error::JointTypeNotMatch { .. } => None,
+ Error::ActuatorStateless(_) => None,
+ Error::PluginStateless(_) => None,
+ Error::BodyNotMocap(_) => None,
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index b93cf3f..fa6cef6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,14 +1,88 @@
-pub fn add(left: u64, right: u64) -> u64 {
- left + right
+pub mod error;
+pub mod physics;
+
+pub use physics::Physics as RawPhysics;
+
+pub trait Physics: std::ops::DerefMut {}
+
+pub trait Task {
+ type Physics: Physics;
+ type Observation: Observation;
+ type Action: Action;
+ fn discount(&self) -> f64;
+ fn init_episode(&self, physics: &mut Self::Physics);
+ fn should_finish_episode(&self, observation: &Self::Observation) -> bool;
+ fn get_reward(&self, observation: &Self::Observation, action: &Self::Action) -> f64;
+}
+
+pub trait Observation {
+ type Physics: Physics;
+ fn generate(physics: &Self::Physics) -> Self;
+}
+
+pub trait Action {
+ type Physics: Physics;
+ fn apply(&self, actuators: &mut physics::Actuators<'_>);
+}
+
+pub struct Environment {
+ task: T,
+ physics: T::Physics,
+}
+
+impl Environment {
+ pub fn new(physics: T::Physics, task: T) -> Self {
+ Self { task, physics }
+ }
+
+ pub fn task(&self) -> &T {
+ &self.task
+ }
+
+ pub fn physics(&self) -> &T::Physics {
+ &self.physics
+ }
+ pub fn physics_mut(&mut self) -> &mut T::Physics {
+ &mut self.physics
+ }
}
-#[cfg(test)]
-mod tests {
- use super::*;
+pub enum TimeStep {
+ Step {
+ observation: O,
+ reward: f64,
+ discount: f64,
+ },
+ Finish {
+ observation: O,
+ reward: f64,
+ },
+}
+
+impl Environment {
+ pub fn reset(&mut self) -> T::Observation {
+ self.task.init_episode(&mut self.physics);
+ T::Observation::generate(&self.physics)
+ }
+
+ pub fn step(&mut self, action: T::Action) -> TimeStep {
+ action.apply(&mut self.physics.actuators());
+ self.physics.step();
+
+ let observation = T::Observation::generate(&self.physics);
+ let reward = self.task.get_reward(&observation, &action);
- #[test]
- fn it_works() {
- let result = add(2, 2);
- assert_eq!(result, 4);
+ if self.task.should_finish_episode(&observation) {
+ TimeStep::Finish {
+ observation,
+ reward,
+ }
+ } else {
+ TimeStep::Step {
+ observation,
+ reward,
+ discount: self.task.discount(),
+ }
+ }
}
}
diff --git a/src/physics.rs b/src/physics.rs
new file mode 100644
index 0000000..baec50f
--- /dev/null
+++ b/src/physics.rs
@@ -0,0 +1,171 @@
+pub use rusty_mujoco as binding;
+pub use binding::{mjModel, mjData, ObjectId, obj, Joint, joint, mjMAXVAL, mjMINVAL};
+
+use crate::error::Error;
+
+pub struct Physics {
+ model: mjModel,
+ data: mjData,
+}
+
+impl Physics {
+ pub fn from_xml(xml_path: impl AsRef) -> Result {
+ let model = binding::mj_loadXML(xml_path.as_ref().to_str().unwrap())?;
+ let data = binding::mj_makeData(&model);
+ Ok(Self { model, data })
+ }
+
+ pub fn from_xml_string(xml_string: impl Into) -> Result {
+ let mut spec = binding::mj_parseXMLString(xml_string.into())?;
+ let model = binding::mj_compile(&mut spec)
+ .ok_or_else(|| Error::Mjs(binding::mjs_getError(&mut spec).unwrap_or_else(String::new)))?;
+ let data = binding::mj_makeData(&model);
+ Ok(Self { model, data })
+ }
+
+ pub fn model(&self) -> &mjModel {
+ &self.model
+ }
+
+ pub fn data(&self) -> &mjData {
+ &self.data
+ }
+ pub fn data_mut(&mut self) -> &mut mjData {
+ &mut self.data
+ }
+
+ pub fn model_data(&self) -> (&mjModel, &mjData) {
+ (&self.model, &self.data)
+ }
+ pub fn model_datamut(&mut self) -> (&mjModel, &mut mjData) {
+ (&self.model, &mut self.data)
+ }
+
+ pub fn step(&mut self) {
+ rusty_mujoco::mj_step(&self.model, &mut self.data);
+ }
+
+ pub fn forward(&mut self) {
+ rusty_mujoco::mj_forward(&self.model, &mut self.data);
+ }
+
+ pub fn reset(&mut self) {
+ rusty_mujoco::mj_resetData(&self.model, &mut self.data);
+ }
+
+ pub fn object_id(&self, name: &str) -> Option> {
+ self.model.object_id(name)
+ }
+
+ pub fn object_name(&self, id: ObjectId) -> String {
+ binding::mj_id2name::(&self.model, id)
+ }
+}
+
+pub struct Actuators<'a> {
+ physics: &'a mut Physics,
+}
+impl<'a> Actuators<'a> {
+ pub fn set(&mut self, id: ObjectId, control: f64) {
+ self.physics.set_ctrl(id, control);
+ }
+}
+impl Physics {
+ pub fn actuators(&mut self) -> Actuators<'_> {
+ Actuators {
+ physics: self,
+ }
+ }
+}
+
+impl Physics {
+ pub fn time(&self) -> f64 {
+ self.data.time()
+ }
+ pub fn set_time(&mut self, time: f64) {
+ self.data.set_time(time);
+ }
+
+ pub fn ctrl(&self, id: ObjectId) -> f64 {
+ self.data.ctrl(id)
+ }
+ pub fn set_ctrl(&mut self, id: ObjectId, value: f64) {
+ self.data.set_ctrl(id, value);
+ }
+
+ pub fn act(&self, id: ObjectId) -> Option {
+ self.data.act(id, &self.model)
+ }
+ /// Set the actuator activation value. `None` when the actuator is stateless.
+ pub fn set_act(&mut self, id: ObjectId, value: f64) -> Option<()> {
+ self.data.set_act(id, value, &self.model)
+ }
+
+ pub fn qpos(&self, id: ObjectId) -> J::Qpos {
+ self.data.qpos(id, &self.model)
+ }
+ pub fn set_qpos(&mut self, id: ObjectId, qpos: J::Qpos) {
+ self.data.set_qpos(id, qpos, &self.model);
+ }
+
+ pub fn qvel(&self, id: ObjectId) -> J::Qvel {
+ self.data.qvel(id, &self.model)
+ }
+ pub fn set_qvel(&mut self, id: ObjectId, qvel: J::Qvel) {
+ self.data.set_qvel(id, qvel, &self.model);
+ }
+
+ pub fn qacc_warmstart(&self, id: ObjectId) -> f64 {
+ self.data.qacc_warmstart(id)
+ }
+ pub fn set_qacc_warmstart(&mut self, id: ObjectId, value: f64) {
+ self.data.set_qacc_warmstart(id, value);
+ }
+
+ pub fn plugin_state(&self, id: ObjectId) -> Option {
+ self.data.plugin_state(id, &self.model)
+ }
+ /// Set the plugin state. Returns `None` if the plugin does not have a state.
+ pub fn set_plugin_state(&mut self, id: ObjectId, value: f64) -> Option<()> {
+ self.data.set_plugin_state(id, value, &self.model)
+ }
+
+ pub fn qfrc_applied(&self, id: ObjectId) -> f64 {
+ self.data.qfrc_applied(id)
+ }
+ pub fn set_qfrc_applied(&mut self, id: ObjectId, value: f64) {
+ self.data.set_qfrc_applied(id, value);
+ }
+
+ pub fn xfrc_applied(&self, id: ObjectId) -> [f64; 6] {
+ self.data.xfrc_applied(id)
+ }
+ pub fn set_xfrc_applied(&mut self, id: ObjectId, value: [f64; 6]) {
+ self.data.set_xfrc_applied(id, value);
+ }
+
+ pub fn eq_active(&self, id: ObjectId) -> bool {
+ self.data.eq_active(id)
+ }
+ pub fn set_eq_active(&mut self, id: ObjectId, value: bool) {
+ self.data.set_eq_active(id, value);
+ }
+
+ /// `None` when the body is not a mocap body.
+ pub fn mocap_pos(&self, id: ObjectId) -> Option<[f64; 3]> {
+ self.data.mocap_pos(id, &self.model)
+ }
+ /// Set the mocap position. Returns `None` if the body is not a mocap body.
+ pub fn set_mocap_pos(&mut self, id: ObjectId, pos: [f64; 3]) -> Option<()> {
+ self.data.set_mocap_pos(id, pos, &self.model)
+ }
+
+ /// `None` when the body is not a mocap body.
+ pub fn mocap_quat(&self, id: ObjectId) -> Option<[f64; 4]> {
+ self.data.mocap_quat(id, &self.model)
+ }
+ /// Set the mocap quaternion. Returns `None` if the body is not a mocap body.
+ pub fn set_mocap_quat(&mut self, id: ObjectId, quat: [f64; 4]) -> Option<()> {
+ self.data.set_mocap_quat(id, quat, &self.model)
+ }
+}