Skip to content

Commit 98b392c

Browse files
authored
Merge pull request #244 from RobotControlStack/juelg/fixes
A set of small fixes
2 parents a804132 + d3276f1 commit 98b392c

File tree

5 files changed

+45
-17
lines changed

5 files changed

+45
-17
lines changed

extensions/rcs_fr3/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ version = "0.5.2"
88
description="RCS libfranka integration"
99
dependencies = [
1010
"rcs>=0.5.2",
11-
"frankik @ git+https://github.com/juelg/frankik",
11+
"frankik",
1212
]
1313
readme = "README.md"
1414
maintainers = [

extensions/rcs_fr3/src/rcs_fr3/creators.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131

3232

3333
class FrankIK(Kinematics):
34-
def __init__(self, allow_elbow_flips: bool = False):
34+
def __init__(self, global_solution: bool = False):
3535
Kinematics.__init__(self)
36-
self.allow_elbow_flips = allow_elbow_flips
36+
self.global_solution = global_solution
3737
self.kin = FrankaKinematics(robot_type="fr3")
3838

3939
def forward(self, q0: np.ndarray[tuple[typing.Literal[7]], np.dtype[np.float64]], tcp_offset: Pose) -> Pose: # type: ignore
@@ -43,10 +43,7 @@ def forward(self, q0: np.ndarray[tuple[typing.Literal[7]], np.dtype[np.float64]]
4343
def inverse( # type: ignore
4444
self, pose: Pose, q0: np.ndarray[tuple[typing.Literal[7]], np.dtype[np.float64]], tcp_offset: Pose
4545
) -> np.ndarray[tuple[typing.Literal[7]], np.dtype[np.float64]] | None:
46-
tcp_offset = self.kin.FrankaHandTCPOffset
47-
return self.kin.inverse(
48-
pose.pose_matrix(), q0, tcp_offset.pose_matrix(), allow_elbow_flips=self.allow_elbow_flips
49-
)
46+
return self.kin.inverse(pose.pose_matrix(), q0, tcp_offset.pose_matrix(), global_solution=self.global_solution)
5047

5148

5249
# FYI: this needs to be in global namespace to avoid auto garbage collection issues

extensions/rcs_fr3/src/rcs_fr3/envs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool,
3030
def get_obs(self, obs: dict | None = None) -> dict[str, Any]:
3131
if obs is None:
3232
obs = dict(self.unwrapped.get_obs())
33-
robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
34-
obs["robot_state"] = vars(robot_state.robot_state)
33+
# robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
34+
# obs["robot_state"] = vars(robot_state.robot_state)
3535
return obs
3636

3737
def reset(

extensions/rcs_tacto/src/rcs_tacto/tacto_wrapper.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ def reset(
8282
self.tacto_last_render = -1 # Reset last render time
8383
colors, depths = self.tacto_sensor.render(self.model, self.data)
8484
for site, color, depth in zip(self.tacto_sites, colors, depths, strict=False):
85-
obs.setdefault("tacto", {}).setdefault(site, {}).setdefault("rgb", {})["data"] = color
85+
obs.setdefault("frames", {}).setdefault(f"tactile_{site}", {}).setdefault("rgb", {})["data"] = color
8686
if self.enable_depth:
87-
obs.setdefault("tacto", {}).setdefault(site, {}).setdefault("depth", {})["data"] = depth
87+
obs.setdefault("frames", {}).setdefault(f"tactile_{site}", {}).setdefault("depth", {})["data"] = depth
8888
return obs, info
8989

9090
def step(self, action: dict[str, Any]):
@@ -94,7 +94,9 @@ def step(self, action: dict[str, Any]):
9494
self.tacto_sensor.updateGUI(colors, depths) if self.visualize else None
9595
self.tacto_last_render = self.data.time
9696
for site, color, depth in zip(self.tacto_sites, colors, depths, strict=False):
97-
obs.setdefault("tacto", {}).setdefault(site, {}).setdefault("rgb", {})["data"] = color
97+
obs.setdefault("frames", {}).setdefault(f"tactile_{site}", {}).setdefault("rgb", {})["data"] = color
9898
if self.enable_depth:
99-
obs.setdefault("tacto", {}).setdefault(site, {}).setdefault("depth", {})["data"] = depth
99+
obs.setdefault("frames", {}).setdefault(f"tactile_{site}", {}).setdefault("depth", {})[
100+
"data"
101+
] = depth
100102
return obs, reward, done, truncated, info

python/rcs/envs/storage_wrapper.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import operator
23
from concurrent.futures import ThreadPoolExecutor, wait
34
from itertools import chain
@@ -10,6 +11,7 @@
1011
import pyarrow as pa
1112
import pyarrow.dataset as ds
1213
import simplejpeg
14+
from PIL import Image
1315

1416

1517
class StorageWrapper(gym.Wrapper):
@@ -22,7 +24,7 @@ def __init__(
2224
instruction: str,
2325
batch_size: int = 32,
2426
schema: Optional[pa.Schema] = None,
25-
start_record: bool = False,
27+
always_record: bool = False,
2628
basename_template: Optional[str] = None,
2729
max_rows_per_group: Optional[int] = None,
2830
max_rows_per_file: Optional[int] = None,
@@ -72,9 +74,10 @@ def __init__(
7274
self.max_rows_per_file = max_rows_per_file
7375
self.buffer: list[dict[str, Any]] = []
7476
self.step_cnt = 0
75-
self._pause = True
77+
self._pause = not always_record
78+
self.always_record = always_record
7679
self.instruction = instruction
77-
self._success = start_record
80+
self._success = False
7881
self._prev_action = None
7982
self.thread_pool = ThreadPoolExecutor()
8083
self.queue: Queue[pa.Table | pa.RecordBatch] = Queue(maxsize=2)
@@ -122,6 +125,7 @@ def _flatten_arrays(self, d: dict[str, Any]):
122125
d.update(updates)
123126

124127
def _encode_images(self, obs: dict[str, Any]):
128+
# images
125129
_ = [
126130
*self.thread_pool.map(
127131
lambda cam: operator.setitem(
@@ -133,6 +137,31 @@ def _encode_images(self, obs: dict[str, Any]):
133137
)
134138
]
135139

140+
# depth
141+
def to_tiff(depth_data):
142+
img_bytes = io.BytesIO()
143+
Image.fromarray(
144+
depth_data.reshape((depth_data.shape[0], depth_data.shape[1])),
145+
).save(
146+
img_bytes, format="TIFF"
147+
) # type: ignore
148+
return img_bytes.getvalue() # type: ignore
149+
150+
_ = [
151+
*self.thread_pool.map(
152+
lambda cam: (
153+
operator.setitem(
154+
obs["frames"][cam]["depth"],
155+
"data",
156+
to_tiff(obs["frames"][cam]["depth"]["data"]),
157+
)
158+
if "depth" in obs["frames"][cam]
159+
else None
160+
),
161+
obs["frames"],
162+
)
163+
]
164+
136165
def step(self, action):
137166
# NOTE: expects the observation to be a dictionary
138167
if self._writer_future.done():
@@ -179,7 +208,7 @@ def start_record(self):
179208
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
180209
if len(self.buffer) > 0:
181210
self._flush()
182-
self._pause = True
211+
self._pause = not self.always_record
183212
self._success = False
184213
self._prev_action = None
185214
obs, info = self.env.reset()

0 commit comments

Comments
 (0)