Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 233fdf4

Browse files
T2T TeamCopybara-Service
authored andcommitted
Adding tests for next_frame basic training step.
PiperOrigin-RevId: 200764433
1 parent e96ca3a commit 233fdf4

File tree

2 files changed

+106
-2
lines changed

2 files changed

+106
-2
lines changed

tensor2tensor/data_generators/video_generated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ def extra_reading_spec(self):
9595
def hparams(self, defaults, unused_model_hparams):
9696
p = defaults
9797
p.input_modality = {
98-
"inputs": ("video:raw", 256),
98+
"inputs": ("video", 256),
9999
"input_frame_number": ("symbol:identity", 1)
100100
}
101-
p.target_modality = ("video:raw", 256)
101+
p.target_modality = ("video", 256)
102102

103103
@staticmethod
104104
def get_circle(x, y, z, c, s):
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Basic tests for video prediction models."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
import numpy as np
21+
22+
from tensor2tensor.data_generators import video_generated # pylint: disable=unused-import
23+
from tensor2tensor.models.research import next_frame
24+
from tensor2tensor.utils import registry
25+
26+
import tensorflow as tf
27+
28+
29+
class NextFrameTest(tf.test.TestCase):
30+
31+
def TestVideoModel(self,
32+
in_frames,
33+
out_frames,
34+
hparams,
35+
model,
36+
expected_last_dim):
37+
38+
x = np.random.random_integers(0, high=255, size=(8, in_frames, 64, 64, 3))
39+
y = np.random.random_integers(0, high=255, size=(8, out_frames, 64, 64, 3))
40+
41+
hparams.video_num_input_frames = in_frames
42+
hparams.video_num_target_frames = out_frames
43+
44+
problem = registry.problem("video_stochastic_shapes10k")
45+
p_hparams = problem.get_hparams(hparams)
46+
hparams.problem = problem
47+
hparams.problem_hparams = p_hparams
48+
49+
with self.test_session() as session:
50+
features = {
51+
"inputs": tf.constant(x, dtype=tf.int32),
52+
"targets": tf.constant(y, dtype=tf.int32),
53+
}
54+
model = model(
55+
hparams, tf.estimator.ModeKeys.TRAIN)
56+
logits, _ = model(features)
57+
session.run(tf.global_variables_initializer())
58+
res = session.run(logits)
59+
expected_shape = y.shape + (expected_last_dim,)
60+
self.assertEqual(res.shape, expected_shape)
61+
62+
def TestBasicModel(self, in_frames, out_frames):
63+
self.TestVideoModel(
64+
in_frames,
65+
out_frames,
66+
next_frame.next_frame(),
67+
next_frame.NextFrameBasic,
68+
256)
69+
70+
def testBasicModelSingleInputFrameSingleOutputFrames(self):
71+
self.TestBasicModel(1, 1)
72+
73+
def testBasicModelSingleInputFrameMultiOutputFrames(self):
74+
self.TestBasicModel(1, 6)
75+
76+
def testBasicModelMultiInputFrameSingleOutputFrames(self):
77+
self.TestBasicModel(4, 1)
78+
79+
def testBasicModelMultiInputFrameMultiOutputFrames(self):
80+
self.TestBasicModel(7, 5)
81+
82+
def TestStochasticModel(self, in_frames, out_frames):
83+
self.TestVideoModel(
84+
in_frames,
85+
out_frames,
86+
next_frame.next_frame_stochastic(),
87+
next_frame.NextFrameStochastic,
88+
1)
89+
90+
def testStochasticModelSingleInputFrameSingleOutputFrames(self):
91+
self.TestStochasticModel(1, 1)
92+
93+
def testStochasticModelSingleInputFrameMultiOutputFrames(self):
94+
self.TestStochasticModel(1, 6)
95+
96+
def testStochasticModelMultiInputFrameSingleOutputFrames(self):
97+
self.TestStochasticModel(4, 1)
98+
99+
def testStochasticModelMultiInputFrameMultiOutputFrames(self):
100+
self.TestStochasticModel(7, 5)
101+
102+
103+
if __name__ == "__main__":
104+
tf.test.main()

0 commit comments

Comments
 (0)