Skip to content

Commit 224d637

Browse files
younikVincent Moens
andauthored
[CI] Fix Minari tests (#2419)
Co-authored-by: Vincent Moens <vmoens@meta.com>
1 parent 2332909 commit 224d637

File tree

2 files changed

+5
-32
lines changed

2 files changed

+5
-32
lines changed

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ dependencies:
1717
- pyyaml
1818
- scipy
1919
- hydra-core
20-
- minari[gcs]
20+
- minari[gcs,hdf5]

test/test_libs.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,34 +2823,11 @@ def _minari_selected_datasets():
28232823

28242824
torch.manual_seed(0)
28252825

2826-
# We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work
2827-
total_keys = sorted(minari.list_remote_datasets())
2828-
assert not any(
2829-
key[-2:] == "10" for key in total_keys
2830-
), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail."
2831-
total_keys_splits = [key.split("-") for key in total_keys]
2826+
total_keys = sorted(
2827+
minari.list_remote_datasets(latest_version=True, compatible_minari_version=True)
2828+
)
28322829
indices = torch.randperm(len(total_keys))[:20]
28332830
keys = [total_keys[idx] for idx in indices]
2834-
keys = [
2835-
key
2836-
for key in keys
2837-
if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
2838-
]
2839-
2840-
def _replace_with_max(key):
2841-
key_split = key.split("-")
2842-
same_entries = (
2843-
torch.tensor(
2844-
[total_key[:-1] == key_split[:-1] for total_key in total_keys_splits]
2845-
)
2846-
.nonzero()
2847-
.squeeze()
2848-
.tolist()
2849-
)
2850-
last_same_entry = same_entries[-1]
2851-
return total_keys[last_same_entry]
2852-
2853-
keys = [_replace_with_max(key) for key in keys]
28542831

28552832
assert len(keys) > 5, keys
28562833
_MINARI_DATASETS += keys
@@ -2880,12 +2857,8 @@ def test_load(self, selected_dataset, split):
28802857
break
28812858

28822859
def test_minari_preproc(self, tmpdir):
2883-
global _MINARI_DATASETS
2884-
if not _MINARI_DATASETS:
2885-
_minari_selected_datasets()
2886-
selected_dataset = _MINARI_DATASETS[0]
28872860
dataset = MinariExperienceReplay(
2888-
selected_dataset,
2861+
"D4RL/pointmaze/large-v2",
28892862
batch_size=32,
28902863
split_trajs=False,
28912864
download="force",

0 commit comments

Comments
 (0)