14
14
# limitations under the License.
15
15
# ******************************************************************************
16
16
import tensorflow as tf
17
- from tensorflow import convert_to_tensor , keras
18
17
19
18
20
- class CRF (keras .layers .Layer ):
19
+ class CRF (tf . keras .layers .Layer ):
21
20
"""
22
21
Conditional Random Field layer (tf.keras)
23
22
`CRF` can be used as the last layer in a network (as a classifier). Input shape (features)
@@ -29,55 +28,36 @@ class CRF(keras.layers.Layer):
29
28
30
29
Args:
31
30
num_classes (int): the number of labels to tag each temporal input.
32
- mode (string, optional): operation mode, 'reg' for regular full sequence learning (all
33
- sequences have equal length), or 'pad' for using with supplied sequence lengths (useful
34
- for padded sequences)
35
31
36
32
Input shape:
37
- 'reg' mode - nD tensor with shape `(batch_size, sentence length, num_classes)`.
38
- 'pad' mode - tuple of `(batch_size, sentence length, num_classes)`, `(batch_size, 1)`
33
+ nD tensor with shape `(batch_size, sentence length, num_classes)`.
39
34
40
35
Output shape:
41
36
nD tensor with shape: `(batch_size, sentence length, num_classes)`.
42
37
"""
43
def __init__(self, num_classes, **kwargs):
    """
    Create the CRF layer.

    Args:
        num_classes (int): the number of labels to tag each temporal input;
            it is coerced with ``int()`` and stored as ``output_dim``.
        **kwargs: forwarded unchanged to the base layer constructor.
    """
    super(CRF, self).__init__(**kwargs)
    # static layer configuration
    self.output_dim = int(num_classes)
    self.input_spec = tf.keras.layers.InputSpec(min_ndim=3)
    self.supports_masking = False
    # populated later: `transitions` starts as None here, and
    # `sequence_lengths` is assigned on every `call`
    self.transitions = None
    self.sequence_lengths = None
58
47
59
48
def get_config(self):
    """
    Return the layer configuration merged with the base layer configuration.

    Returns:
        dict: the base layer config extended with ``output_dim``,
        ``supports_masking`` and ``transitions`` (the evaluated transition
        matrix, or ``None`` when no transition variable exists yet).
    """
    # `transitions` is None until the layer creates its weights; evaluating
    # None through the backend would raise, so only evaluate a real variable.
    transitions = self.transitions
    if transitions is not None:
        transitions = tf.keras.backend.eval(transitions)

    # NOTE(review): the keys below do not match the `__init__` parameter
    # names (`num_classes`), so `from_config` may not round-trip - confirm.
    config = {
        'output_dim': self.output_dim,
        'supports_masking': self.supports_masking,
        'transitions': transitions
    }
    merged = dict(super(CRF, self).get_config())
    # layer-specific entries take precedence over the base entries
    merged.update(config)
    return merged
68
56
69
57
def build (self , input_shape ):
70
- if self .mode == 'pad' :
71
- assert len (input_shape ) == 2
72
- assert len (input_shape [0 ]) == 3
73
- assert len (input_shape [1 ]) == 2
74
- f_shape = tf .TensorShape (input_shape [0 ])
75
- input_spec = [keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]}),
76
- keras .layers .InputSpec (min_ndim = 2 , axes = {- 1 : 1 }, dtype = tf .int32 )]
77
- else :
78
- assert len (input_shape ) == 3
79
- f_shape = tf .TensorShape (input_shape )
80
- input_spec = keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]})
58
+ assert len (input_shape ) == 3
59
+ f_shape = tf .TensorShape (input_shape )
60
+ input_spec = tf .keras .layers .InputSpec (min_ndim = 3 , axes = {- 1 : f_shape [- 1 ]})
81
61
82
62
if f_shape [- 1 ] is None :
83
63
raise ValueError ('The last dimension of the inputs to `CRF` '
@@ -92,21 +72,26 @@ def build(self, input_shape):
92
72
trainable = True )
93
73
self .built = True
94
74
95
# pylint: disable=arguments-differ
def call(self, inputs, sequence_lengths=None, **kwargs):
    """
    Decode the most likely tag sequence for each input sequence.

    Args:
        inputs: tensor of shape `(batch_size, sentence length, num_classes)`.
        sequence_lengths (optional): int32 tensor (or array-like) of shape
            `(batch_size, 1)` holding the true length of each sequence.
            When omitted, every sequence is assumed to span the full
            `sentence length` dimension of ``inputs``.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        In the training phase, the input scores unchanged; otherwise the
        one-hot encoding of the Viterbi-decoded tag sequence.

    Raises:
        ValueError: if ``sequence_lengths`` is not an int32 tensor of
            shape `(batch_size, 1)`.
    """
    sequences = tf.convert_to_tensor(inputs, dtype=self.dtype)

    if sequence_lengths is not None:
        # Convert once so lists / numpy arrays are accepted as well as tensors.
        lengths = tf.convert_to_tensor(sequence_lengths)
        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if lengths.dtype != tf.int32:
            raise ValueError('`sequence_lengths` must have dtype int32, got %s'
                             % lengths.dtype)
        lengths_shape = lengths.get_shape()
        if lengths_shape.ndims != 2 or lengths_shape.as_list()[1] != 1:
            raise ValueError('`sequence_lengths` must have shape (batch_size, 1), '
                             'got %s' % lengths_shape)
        self.sequence_lengths = tf.keras.backend.flatten(lengths)
    else:
        # No lengths supplied: treat every sequence as full length.
        dynamic_shape = tf.shape(sequences)
        self.sequence_lengths = \
            tf.ones(dynamic_shape[0], dtype=tf.int32) * dynamic_shape[1]

    viterbi_sequence, _ = tf.contrib.crf.crf_decode(sequences, self.transitions,
                                                    self.sequence_lengths)
    output = tf.keras.backend.one_hot(viterbi_sequence, self.output_dim)
    # Training passes the raw scores through; inference returns the decoded
    # one-hot tags.
    return tf.keras.backend.in_train_phase(sequences, output)
107
92
108
93
def loss (self , y_true , y_pred ):
109
- y_pred = convert_to_tensor (y_pred , dtype = self .dtype )
94
+ y_pred = tf . convert_to_tensor (y_pred , dtype = self .dtype )
110
95
log_likelihood , self .transitions = \
111
96
tf .contrib .crf .crf_log_likelihood (y_pred ,
112
97
tf .cast (tf .keras .backend .argmax (y_true ),
@@ -116,12 +101,8 @@ def loss(self, y_true, y_pred):
116
101
return tf .reduce_mean (- log_likelihood )
117
102
118
103
def compute_output_shape(self, input_shape):
    """
    Compute the output shape of the layer.

    Args:
        input_shape: shape of the input, `(batch_size, sentence length,
            num_classes)`, given as a tuple, list or ``tf.TensorShape``.

    Returns:
        tuple: `(batch_size, sentence length, output_dim)`.

    Raises:
        ValueError: if ``input_shape`` does not have rank 3.
    """
    # Normalise first: concatenating a list slice with a tuple raises
    # TypeError, so plain `input_shape[:2] + (...)` only works for tuples
    # (and for TensorShape only where it supports `+`).
    shape = tf.TensorShape(input_shape)
    shape.assert_has_rank(3)
    return tuple(shape.as_list()[:2]) + (self.output_dim,)
125
106
126
107
@property
127
108
def viterbi_accuracy (self ):
@@ -130,7 +111,7 @@ def accuracy(y_true, y_pred):
130
111
sequence_lengths = tf .ones (shape [0 ], dtype = tf .int32 ) * (shape [1 ])
131
112
viterbi_sequence , _ = tf .contrib .crf .crf_decode (y_pred , self .transitions ,
132
113
sequence_lengths )
133
- output = keras .backend .one_hot (viterbi_sequence , self .output_dim )
114
+ output = tf . keras .backend .one_hot (viterbi_sequence , self .output_dim )
134
115
return tf .keras .metrics .categorical_accuracy (y_true , output )
135
116
accuracy .func_name = 'viterbi_accuracy'
136
117
return accuracy
0 commit comments