diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index a3d5928960..d8d97bf216 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -23,6 +23,7 @@ # TODO: Wrap functions in @tf.function once # https://github.com/tensorflow/tensorflow/issues/29075 is resolved + def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -30,8 +31,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Args: inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the unnormalized score. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the unnormalized score. sequence_lengths: A [batch_size] vector of true sequence lengths. transition_params: A [num_tags, num_tags] transition matrix. Returns: @@ -171,15 +172,16 @@ def crf_log_likelihood(inputs, Args: inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the log-likelihood. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the log-likelihood. sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix, if available. + transition_params: A [num_tags, num_tags] transition matrix, + if available. Returns: log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of each example, given the sequence of tag indices. - transition_params: A [num_tags, num_tags] transition matrix. This is either - provided by the caller or created in this function. + transition_params: A [num_tags, num_tags] transition matrix. This is + either provided by the caller or created in this function. """ # Get shape information. num_tags = inputs.shape[2] @@ -252,7 +254,8 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) # Encode the indices in a flattened representation. - flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices + flattened_transition_indices = start_tag_indices * \ + num_tags + end_tag_indices flattened_transition_params = tf.reshape(transition_params, [-1]) # Get the binary scores based on the flattened representation. @@ -281,7 +284,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths): sequence_lengths: A [batch_size] vector of true sequence lengths. Returns: - new_alphas: A [batch_size, num_tags] matrix containing the + new_alphas: A [batch_size, num_tags] matrix containing the new alpha values. """