@@ -68,6 +68,22 @@ def log_decode_results(inputs,
68
68
identity_output = False ,
69
69
log_results = True ):
70
70
"""Log inference results."""
71
+
72
+ # TODO(lukaszkaiser) refactor this into feature_encoder
73
+ is_video = "video" in problem_name
74
+ if is_video :
75
+ def fix_and_save_video (vid , prefix ):
76
+ save_path_template = os .path .join (
77
+ model_dir , "%s_%s_%d_{}.png" % (problem_name , prefix , prediction_idx ))
78
+ # this is only required for predictions
79
+ if vid .shape [- 1 ] == 1 :
80
+ vid = np .squeeze (vid , axis = - 1 )
81
+ save_video (vid , save_path_template )
82
+ tf .logging .info ("Saving video: {}" .format (prediction_idx ))
83
+ fix_and_save_video (inputs , "inputs" )
84
+ fix_and_save_video (outputs , "outputs" )
85
+ fix_and_save_video (targets , "targets" )
86
+
71
87
is_image = "image" in problem_name
72
88
decoded_inputs = None
73
89
if is_image and save_images :
@@ -80,7 +96,7 @@ def log_decode_results(inputs,
80
96
else :
81
97
decoded_inputs = inputs_vocab .decode (_save_until_eos (inputs , is_image ))
82
98
83
- if log_results :
99
+ if log_results and not is_video :
84
100
tf .logging .info ("Inference results INPUT: %s" % decoded_inputs )
85
101
86
102
decoded_targets = None
@@ -93,8 +109,9 @@ def log_decode_results(inputs,
93
109
decoded_outputs = targets_vocab .decode (_save_until_eos (outputs , is_image ))
94
110
if targets is not None and log_results :
95
111
decoded_targets = targets_vocab .decode (_save_until_eos (targets , is_image ))
96
- tf .logging .info ("Inference results OUTPUT: %s" % decoded_outputs )
97
- if targets is not None and log_results :
112
+ if not is_video :
113
+ tf .logging .info ("Inference results OUTPUT: %s" % decoded_outputs )
114
+ if targets is not None and log_results and not is_video :
98
115
tf .logging .info ("Inference results TARGET: %s" % decoded_targets )
99
116
return decoded_inputs , decoded_outputs , decoded_targets
100
117
@@ -518,6 +535,21 @@ def _interactive_input_fn(hparams, decode_hp):
518
535
yield features
519
536
520
537
538
+ def save_video (video , save_path_template ):
539
+ """Save frames of the videos into files."""
540
+ try :
541
+ from PIL import Image # pylint: disable=g-import-not-at-top
542
+ except ImportError as e :
543
+ tf .logging .warning (
544
+ "Showing and saving an image requires PIL library to be "
545
+ "installed: %s" , e )
546
+ raise NotImplementedError ("Image display and save not implemented." )
547
+
548
+ for i , frame in enumerate (video ):
549
+ save_path = save_path_template .format (i )
550
+ Image .fromarray (np .uint8 (frame )).save (save_path )
551
+
552
+
521
553
def show_and_save_image (img , save_path ):
522
554
try :
523
555
import matplotlib .pyplot as plt # pylint: disable=g-import-not-at-top
0 commit comments