37
37
38
38
class TransformerTest (tf .test .TestCase ):
39
39
40
- def getModel (self , hparams , mode = tf .estimator .ModeKeys .TRAIN , has_input = True ):
40
+ def getModel (self , hparams = None , mode = tf .estimator .ModeKeys .TRAIN ,
41
+ has_input = True , model_cls = transformer .Transformer ):
42
+ if hparams is None :
43
+ hparams = transformer .transformer_tiny ()
41
44
hparams .hidden_size = 8
42
45
hparams .filter_size = 32
43
46
hparams .num_heads = 1
@@ -58,7 +61,7 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN, has_input=True):
58
61
"target_space_id" : tf .constant (1 , dtype = tf .int32 )
59
62
}
60
63
61
- return transformer . Transformer (hparams , mode , p_hparams ), features
64
+ return model_cls (hparams , mode , p_hparams ), features
62
65
63
66
def testTransformer (self ):
64
67
model , features = self .getModel (transformer .transformer_small ())
@@ -240,5 +243,47 @@ def testTransformerWithEncoderDecoderAttentionLoss(self):
240
243
self .assertEqual (res .shape , ())
241
244
242
245
246
+ class TransformerScorerTest (TransformerTest ):
247
+
248
+ def testReturnsScores (self ):
249
+ model , features = self .getModel (
250
+ mode = tf .estimator .ModeKeys .PREDICT ,
251
+ model_cls = transformer .TransformerScorer )
252
+ infer_out = model .infer (features )
253
+ self .assertTrue ("outputs" in infer_out )
254
+ self .assertTrue ("scores" in infer_out )
255
+
256
+ with self .test_session () as session :
257
+ session .run (tf .global_variables_initializer ())
258
+ infer_out = session .run (infer_out )
259
+ self .assertEqual ((BATCH_SIZE ,), infer_out ["scores" ].shape )
260
+ self .assertEqual ((BATCH_SIZE , TARGET_LENGTH ), infer_out ["outputs" ].shape )
261
+
262
+ def testVarNames (self ):
263
+ with tf .Graph ().as_default ():
264
+ model , features = self .getModel (
265
+ mode = tf .estimator .ModeKeys .PREDICT ,
266
+ model_cls = transformer .TransformerScorer )
267
+ _ = model .infer (features )
268
+ scorer_vars = [v .name for v in tf .global_variables ()]
269
+
270
+ with tf .Graph ().as_default ():
271
+ model , features = self .getModel (
272
+ mode = tf .estimator .ModeKeys .EVAL ,
273
+ model_cls = transformer .TransformerScorer )
274
+ _ = model (features )
275
+ scorer_eval_vars = [v .name for v in tf .global_variables ()]
276
+
277
+ with tf .Graph ().as_default ():
278
+ model , features = self .getModel (
279
+ mode = tf .estimator .ModeKeys .EVAL ,
280
+ model_cls = transformer .Transformer )
281
+ _ = model (features )
282
+ transformer_vars = [v .name for v in tf .global_variables ()]
283
+
284
+ self .assertEqual (sorted (scorer_vars ), sorted (transformer_vars ))
285
+ self .assertEqual (sorted (scorer_eval_vars ), sorted (transformer_vars ))
286
+
287
+
243
288
if __name__ == "__main__" :
244
289
tf .test .main ()
0 commit comments