@@ -2823,34 +2823,11 @@ def _minari_selected_datasets():
2823
2823
2824
2824
torch .manual_seed (0 )
2825
2825
2826
- # We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work
2827
- total_keys = sorted (minari .list_remote_datasets ())
2828
- assert not any (
2829
- key [- 2 :] == "10" for key in total_keys
2830
- ), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail."
2831
- total_keys_splits = [key .split ("-" ) for key in total_keys ]
2826
+ total_keys = sorted (
2827
+ minari .list_remote_datasets (latest_version = True , compatible_minari_version = True )
2828
+ )
2832
2829
indices = torch .randperm (len (total_keys ))[:20 ]
2833
2830
keys = [total_keys [idx ] for idx in indices ]
2834
- keys = [
2835
- key
2836
- for key in keys
2837
- if "=0.4" in minari .list_remote_datasets ()[key ]["minari_version" ]
2838
- ]
2839
-
2840
- def _replace_with_max (key ):
2841
- key_split = key .split ("-" )
2842
- same_entries = (
2843
- torch .tensor (
2844
- [total_key [:- 1 ] == key_split [:- 1 ] for total_key in total_keys_splits ]
2845
- )
2846
- .nonzero ()
2847
- .squeeze ()
2848
- .tolist ()
2849
- )
2850
- last_same_entry = same_entries [- 1 ]
2851
- return total_keys [last_same_entry ]
2852
-
2853
- keys = [_replace_with_max (key ) for key in keys ]
2854
2831
2855
2832
assert len (keys ) > 5 , keys
2856
2833
_MINARI_DATASETS += keys
@@ -2880,12 +2857,8 @@ def test_load(self, selected_dataset, split):
2880
2857
break
2881
2858
2882
2859
def test_minari_preproc (self , tmpdir ):
2883
- global _MINARI_DATASETS
2884
- if not _MINARI_DATASETS :
2885
- _minari_selected_datasets ()
2886
- selected_dataset = _MINARI_DATASETS [0 ]
2887
2860
dataset = MinariExperienceReplay (
2888
- selected_dataset ,
2861
+ "D4RL/pointmaze/large-v2" ,
2889
2862
batch_size = 32 ,
2890
2863
split_trajs = False ,
2891
2864
download = "force" ,
0 commit comments