Skip to content

Commit 06d8e64

Browse files
authored
[BUG] Fix data loader (#2810)
* Update _data_loaders.py
* Update _data_loaders.py
* stop deleting directories
1 parent 748d044 commit 06d8e64

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

aeon/datasets/_data_loaders.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,17 @@ def _download_and_extract(url, extract_path=None):
468468
extract_path = os.path.join(extract_path, "%s/" % file_name.split(".")[0])
469469

470470
try:
471-
if not os.path.exists(extract_path):
471+
already_exists = os.path.exists(extract_path)
472+
if not already_exists:
472473
os.makedirs(extract_path)
473474
zipfile.ZipFile(zip_file_name, "r").extractall(extract_path)
474475
shutil.rmtree(dl_dir)
475476
return extract_path
476477
except zipfile.BadZipFile:
477478
shutil.rmtree(dl_dir)
478-
if os.path.exists(extract_path):
479-
shutil.rmtree(extract_path)
479+
if not already_exists:
480+
if os.path.exists(extract_path):
481+
shutil.rmtree(extract_path)
480482
raise zipfile.BadZipFile(
481483
"Could not unzip dataset. Please make sure the URL is valid."
482484
)
@@ -546,7 +548,7 @@ def _load_tsc_dataset(
546548
except zipfile.BadZipFile as e:
547549
raise ValueError(
548550
f"Invalid dataset name ={name} is not available on extract path ="
549-
f"{extract_path}. Nor is it available on {url}",
551+
f"{extract_path} nor is it available on {url}",
550552
) from e
551553

552554
return _load_saved_dataset(
@@ -1342,7 +1344,7 @@ def load_classification(
13421344
try_zenodo = False
13431345
error_str = (
13441346
f"Invalid dataset name ={name} that is not available on extract path "
1345-
f"={extract_path}. Nor is it available on "
1347+
f"={extract_path} nor is it available on "
13461348
f"https://timeseriesclassification.com/ or zenodo."
13471349
)
13481350
try:

aeon/datasets/tests/test_data_loaders.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import shutil
77
import tempfile
88
from urllib.error import URLError
9+
from zipfile import BadZipFile
910

1011
import numpy as np
1112
import pandas as pd
@@ -24,6 +25,7 @@
2425
from aeon.datasets._data_loaders import (
2526
CONNECTION_ERRORS,
2627
_alias_datatype_check,
28+
_download_and_extract,
2729
_get_channel_strings,
2830
_load_data,
2931
_load_header_info,
@@ -551,3 +553,21 @@ def test_load_tsc_dataset():
551553
assert isinstance(X, np.ndarray) and isinstance(y, np.ndarray)
552554
with pytest.raises(ValueError, match="Invalid dataset name"):
553555
_load_tsc_dataset("FOO", split="TEST", extract_path=tmp)
556+
557+
558+
@pytest.mark.skipif(
559+
PR_TESTING,
560+
reason="Only run on overnights because of intermittent fail for read/write",
561+
)
562+
@pytest.mark.xfail(raises=(URLError, TimeoutError, ConnectionError))
563+
def test_download_and_extract():
564+
"""Test that the function does not delete a directory if already present."""
565+
name = "Foo"
566+
with tempfile.TemporaryDirectory() as tmp:
567+
extract_path = os.path.join(tmp, name)
568+
os.makedirs(extract_path)
569+
url = "https://timeseriesclassification.com/aeon-toolkit/%s.zip" % name
570+
try:
571+
_download_and_extract(url, extract_path=extract_path)
572+
except BadZipFile:
573+
assert os.path.exists(extract_path)

0 commit comments

Comments
 (0)