diff --git a/environment.yml b/environment.yml index ddfa06e..acd3c9b 100644 --- a/environment.yml +++ b/environment.yml @@ -21,7 +21,7 @@ dependencies: - future==0.18.2 - gdown==4.7.1 - glfw==2.6.2 - - gym==0.21.0 + - gym - gymnasium==0.29.1 - h5py==3.9.0 - huggingface-hub==0.16.4 diff --git a/scripts/download_libero_datasets.py b/scripts/download_libero_datasets.py index 6f87488..22e1864 100644 --- a/scripts/download_libero_datasets.py +++ b/scripts/download_libero_datasets.py @@ -10,7 +10,7 @@ def parse_args(): parser.add_argument( "--download-dir", type=str, - default="./data/", + default="./data/libero/", ) parser.add_argument( "--datasets", diff --git a/scripts/preprocess_libero.py b/scripts/preprocess_libero.py index e817a74..7e7cfe3 100644 --- a/scripts/preprocess_libero.py +++ b/scripts/preprocess_libero.py @@ -271,8 +271,8 @@ def main(root, save, suite, skip_exist): # load task name embeddings task_bert_embs_dict = get_task_bert_embs(root) - for source_h5 in os.listdir(suite_dir): - source_h5_path = os.path.join(suite_dir, source_h5) + for source_h5_path in glob(os.path.join(suite_dir, "*.hdf5")): + source_h5 = os.path.basename(source_h5_path) file_name = source_h5.split('.')[0] task_name = get_task_name_from_file_name(file_name)