Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
Squadrick committed Jul 31, 2019
1 parent 7a0435b commit 205ebd1
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@
# TODO: Wrap functions in @tf.function once
# https://github.com/tensorflow/tensorflow/issues/29075 is resolved


def crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params):
"""Computes the unnormalized score for a tag sequence.
Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
compute the unnormalized score.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
we compute the unnormalized score.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix.
Returns:
Expand Down Expand Up @@ -171,15 +172,16 @@ def crf_log_likelihood(inputs,
Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
compute the log-likelihood.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
we compute the log-likelihood.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix, if available.
transition_params: A [num_tags, num_tags] transition matrix,
if available.
Returns:
log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
each example, given the sequence of tag indices.
transition_params: A [num_tags, num_tags] transition matrix. This is either
provided by the caller or created in this function.
transition_params: A [num_tags, num_tags] transition matrix. This is
either provided by the caller or created in this function.
"""
# Get shape information.
num_tags = inputs.shape[2]
Expand Down Expand Up @@ -252,7 +254,8 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions])

# Encode the indices in a flattened representation.
flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
flattened_transition_indices = start_tag_indices * \
num_tags + end_tag_indices
flattened_transition_params = tf.reshape(transition_params, [-1])

# Get the binary scores based on the flattened representation.
Expand Down Expand Up @@ -281,7 +284,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
sequence_lengths: A [batch_size] vector of true sequence lengths.
Returns:
new_alphas: A [batch_size, num_tags] matrix containing the
new_alphas: A [batch_size, num_tags] matrix containing the
new alpha values.
"""

Expand Down

0 comments on commit 205ebd1

Please sign in to comment.