Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6fb1537

Browse files
Ryan SepassiCopybara-Service
authored andcommitted
Lazy load PIL
PiperOrigin-RevId: 200762820
1 parent 9598fa2 commit 6fb1537

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

tensor2tensor/data_generators/bair_robot_pushing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import os
2828
import tarfile
2929
import numpy as np
30-
from PIL import Image
3130

3231
from tensor2tensor.data_generators import generator_utils
3332
from tensor2tensor.data_generators import problem
@@ -40,6 +39,12 @@
4039
"http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar")
4140

4241

42+
# Lazy load PIL.Image
43+
def PIL_Image(): # pylint: disable=invalid-name
44+
from PIL import Image # pylint: disable=g-import-not-at-top
45+
return Image
46+
47+
4348
@registry.register_problem
4449
class VideoBairRobotPushing(video_utils.VideoProblem):
4550
"""Berkeley (BAIR) robot pushing dataset."""
@@ -85,7 +90,7 @@ def parse_frames(self, filenames):
8590
state_name = state_key.format(i)
8691

8792
byte_str = x.features.feature[image_name].bytes_list.value[0]
88-
img = Image.frombytes(
93+
img = PIL_Image().frombytes(
8994
"RGB", (self.frame_width, self.frame_height), byte_str)
9095
arr = np.array(img.getdata())
9196
frame = arr.reshape(

tensor2tensor/data_generators/google_robot_pushing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import io
2828
import os
2929
import numpy as np
30-
from PIL import Image
3130

3231
from tensor2tensor.data_generators import generator_utils
3332
from tensor2tensor.data_generators import problem
@@ -42,6 +41,12 @@
4241
DATA_TEST_NOVEL = (5, "/push_testnovel/push_testnovel.tfrecord-{:05d}-of-00005")
4342

4443

44+
# Lazy load PIL.Image
45+
def PIL_Image(): # pylint: disable=invalid-name
46+
from PIL import Image # pylint: disable=g-import-not-at-top
47+
return Image
48+
49+
4550
@registry.register_problem
4651
class VideoGoogleRobotPushing(video_utils.VideoProblem):
4752
"""Google robot pushing dataset."""
@@ -90,10 +95,10 @@ def parse_frames(self, filename):
9095
state_name = state_key.format(i)
9196

9297
byte_str = x.features.feature[image_name].bytes_list.value[0]
93-
img = Image.open(io.BytesIO(byte_str))
98+
img = PIL_Image().open(io.BytesIO(byte_str))
9499
# The original images are much bigger than 64x64
95100
img = img.resize((self.frame_width, self.frame_height),
96-
resample=Image.BILINEAR)
101+
resample=PIL_Image().BILINEAR)
97102
arr = np.array(img.getdata())
98103
frame = arr.reshape(
99104
self.frame_width, self.frame_height, self.num_channels)

0 commit comments

Comments
 (0)