1+ import io
12import operator
23from concurrent .futures import ThreadPoolExecutor , wait
34from itertools import chain
1011import pyarrow as pa
1112import pyarrow .dataset as ds
1213import simplejpeg
14+ from PIL import Image
1315
1416
1517class StorageWrapper (gym .Wrapper ):
@@ -22,7 +24,7 @@ def __init__(
2224 instruction : str ,
2325 batch_size : int = 32 ,
2426 schema : Optional [pa .Schema ] = None ,
25- start_record : bool = False ,
27+ always_record : bool = False ,
2628 basename_template : Optional [str ] = None ,
2729 max_rows_per_group : Optional [int ] = None ,
2830 max_rows_per_file : Optional [int ] = None ,
@@ -72,9 +74,10 @@ def __init__(
7274 self .max_rows_per_file = max_rows_per_file
7375 self .buffer : list [dict [str , Any ]] = []
7476 self .step_cnt = 0
75- self ._pause = True
77+ self ._pause = not always_record
78+ self .always_record = always_record
7679 self .instruction = instruction
77- self ._success = start_record
80+ self ._success = False
7881 self ._prev_action = None
7982 self .thread_pool = ThreadPoolExecutor ()
8083 self .queue : Queue [pa .Table | pa .RecordBatch ] = Queue (maxsize = 2 )
@@ -122,6 +125,7 @@ def _flatten_arrays(self, d: dict[str, Any]):
122125 d .update (updates )
123126
124127 def _encode_images (self , obs : dict [str , Any ]):
128+ # images
125129 _ = [
126130 * self .thread_pool .map (
127131 lambda cam : operator .setitem (
@@ -133,6 +137,31 @@ def _encode_images(self, obs: dict[str, Any]):
133137 )
134138 ]
135139
140+ # depth
141+ def to_tiff (depth_data ):
142+ img_bytes = io .BytesIO ()
143+ Image .fromarray (
144+ depth_data .reshape ((depth_data .shape [0 ], depth_data .shape [1 ])),
145+ ).save (
146+ img_bytes , format = "TIFF"
147+ ) # type: ignore
148+ return img_bytes .getvalue () # type: ignore
149+
150+ _ = [
151+ * self .thread_pool .map (
152+ lambda cam : (
153+ operator .setitem (
154+ obs ["frames" ][cam ]["depth" ],
155+ "data" ,
156+ to_tiff (obs ["frames" ][cam ]["depth" ]["data" ]),
157+ )
158+ if "depth" in obs ["frames" ][cam ]
159+ else None
160+ ),
161+ obs ["frames" ],
162+ )
163+ ]
164+
136165 def step (self , action ):
137166 # NOTE: expects the observation to be a dictionary
138167 if self ._writer_future .done ():
@@ -179,7 +208,7 @@ def start_record(self):
179208 def reset (self , * , seed : int | None = None , options : dict [str , Any ] | None = None ):
180209 if len (self .buffer ) > 0 :
181210 self ._flush ()
182- self ._pause = True
211+ self ._pause = not self . always_record
183212 self ._success = False
184213 self ._prev_action = None
185214 obs , info = self .env .reset ()
0 commit comments