Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions flybody/tasks/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,36 @@ def observable_indices_in_tensor(
return sorted_obs_dict


def wing_qpos_to_conventional(model_wing_qpos: np.ndarray,
body_pitch_angle: float = 47.5,
) -> np.ndarray:
"""Transform model wing joint qpos to conventional wing kinematics definition.

Args:
model_wing_qpos: Wing MjData.qpos in radians, shape (B, 6).
Order of joints: yaw, roll, pitch, yaw, roll, pitch.
Left-right order is arbitrary.
body_pitch_angle: Body pitch angle for initial flight pose, relative to
ground, degrees. 0: horizontal body position. Default value from
https://doi.org/10.1126/science.1248955

Returns:
Wing angles transformed to conventional representation.
"""
if not isinstance(model_wing_qpos, np.ndarray):
model_wing_qpos = np.array(model_wing_qpos)
conventional = np.zeros_like(model_wing_qpos)
body_pitch_angle = np.deg2rad(body_pitch_angle)
# Yaw, doesn't require transformation.
conventional[..., [0, 3]] = model_wing_qpos[..., [0, 3]].copy()
# Roll.
conventional[..., [1, 4]] = - model_wing_qpos[..., [1, 4]]
# Pitch.
conventional[..., [2, 5]] = (
np.pi / 2 - body_pitch_angle - model_wing_qpos[..., [2, 5]])
return conventional


def get_random_policy(action_spec: 'dm_env.specs.BoundedArray',
minimum: float = -0.2,
maximum: float = 0.2) -> Callable[[Any], np.ndarray]:
Expand Down