Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions keras_cv/src/datasets/pascal_voc/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class and instance segmentation masks.
import os.path
import random
import tarfile
import xml
from xml.etree import ElementTree

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -152,18 +152,21 @@ def _download_data_file(
if not local_dir_path:
# download to ~/.keras/datasets/fname
cache_dir = os.path.join(os.path.expanduser("~"), ".keras/datasets")
fname = os.path.join(cache_dir, os.path.basename(data_url))
fname = os.path.join(os.path.basename(data_url))
else:
# Make sure the directory exists
if not os.path.exists(local_dir_path):
os.makedirs(local_dir_path, exist_ok=True)
# download to local_dir_path/fname
fname = os.path.join(local_dir_path, os.path.basename(data_url))
fname = os.path.join(os.path.basename(data_url))
cache_dir = local_dir_path
data_directory = os.path.join(os.path.dirname(fname), extracted_dir)
if not override_extract and os.path.exists(data_directory):
logging.info("data directory %s already exist", data_directory)
return data_directory
data_file_path = keras.utils.get_file(fname=fname, origin=data_url)
data_file_path = keras.utils.get_file(
fname=fname, origin=data_url, cache_dir=cache_dir
)
# Extra the data into the same directory as the tar file.
data_directory = os.path.dirname(data_file_path)
logging.info("Extract data into %s", data_directory)
Expand All @@ -180,7 +183,7 @@ def _parse_annotation_data(annotation_file_path):

"""
with tf.io.gfile.GFile(annotation_file_path, "r") as f:
root = xml.etree.ElementTree.parse(f).getroot()
root = ElementTree.parse(f).getroot()

size = root.find("size")
width = int(size.find("width").text)
Expand Down
Loading