Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the Kendalls Tau metric when used in graph mode #2739

Merged
merged 4 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 69 additions & 77 deletions tensorflow_addons/metrics/kendalls_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,22 @@ def __init__(
self.preds_max = preds_max
self.actual_cutpoints = actual_cutpoints
self.preds_cutpoints = preds_cutpoints
self.reset_state()
self.actual_cuts = tf.linspace(
tf.cast(self.actual_min, tf.float32),
tf.cast(self.actual_max, tf.float32),
self.actual_cutpoints - 1,
)
self.preds_cuts = tf.linspace(
tf.cast(self.preds_min, tf.float32),
tf.cast(self.preds_max, tf.float32),
self.preds_cutpoints - 1,
)
self.m = self.add_weight(
"m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64
)
self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64)
self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64)
self.n = self.add_weight("n", (), dtype=tf.int64)

def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates ranks.
Expand All @@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Returns:
Update op.
"""
if y_true.shape and y_true.shape[0]:
i = tf.searchsorted(
self.actual_cuts,
tf.cast(tf.reshape(y_true, -1), self.actual_cuts.dtype),
i = tf.searchsorted(
self.actual_cuts,
tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype),
)
j = tf.searchsorted(
self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype)
)

m = tf.sparse.from_dense(self.m)
nrow = tf.sparse.from_dense(self.nrow)
ncol = tf.sparse.from_dense(self.ncol)

k = 0
while k < tf.shape(i)[0]:
m = tf.sparse.add(
m,
tf.SparseTensor(
[[i[k], j[k]]],
tf.cast([1], dtype=m.dtype),
self.m.shape,
),
)
j = tf.searchsorted(
self.preds_cuts, tf.cast(tf.reshape(y_pred, -1), self.preds_cuts.dtype)
nrow = tf.sparse.add(
nrow,
tf.SparseTensor(
[[i[k]]],
tf.cast([1], dtype=nrow.dtype),
self.nrow.shape,
),
)

def body(k, n, m, nrow, ncol):
return (
k + 1,
n + 1,
tf.sparse.add(
m,
tf.SparseTensor(
[[i[k], j[k]]],
tf.cast([1], dtype=self.m.dtype),
self.m.shape,
),
),
tf.sparse.add(
nrow,
tf.SparseTensor(
[[i[k]]],
tf.cast([1], dtype=self.nrow.dtype),
self.nrow.shape,
),
),
tf.sparse.add(
ncol,
tf.SparseTensor(
[[j[k]]],
tf.cast([1], dtype=self.ncol.dtype),
self.ncol.shape,
),
),
)

_, self.n, self.m, self.nrow, self.ncol = tf.while_loop(
lambda k, n, m, nrow, ncol: k < i.shape[0],
body=body,
loop_vars=(0, self.n, self.m, self.nrow, self.ncol),
ncol = tf.sparse.add(
ncol,
tf.SparseTensor(
[[j[k]]],
tf.cast([1], dtype=ncol.dtype),
self.ncol.shape,
),
)
k += 1

self.n.assign_add(tf.cast(k, tf.int64))
self.m.assign(tf.sparse.to_dense(m))
self.nrow.assign(tf.sparse.to_dense(nrow))
self.ncol.assign(tf.sparse.to_dense(ncol))

def result(self):
m_dense = tf.sparse.to_dense(tf.cast(self.m, tf.float32))
n_cap = tf.cumsum(
tf.cumsum(
tf.slice(tf.pad(m_dense, [[1, 0], [1, 0]]), [0, 0], self.m.shape),
axis=0,
),
axis=1,
)
m = tf.cast(self.m, tf.float32)
n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1)
# Number of concordant pairs.
p = tf.math.reduce_sum(tf.multiply(n_cap, m_dense))
sum_m_squard = tf.math.reduce_sum(tf.math.square(m_dense))
p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:]))
sum_m_squard = tf.math.reduce_sum(tf.math.square(m))
# Ties in x.
t = (
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.nrow)))
tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32)
- sum_m_squard
) / 2.0
# Ties in y.
u = (
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.ncol)))
tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32)
- sum_m_squard
) / 2.0
# Ties in both.
b = tf.math.reduce_sum(tf.multiply(m_dense, (m_dense - 1.0))) / 2.0
b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0
# Number of discordant pairs.
n = tf.cast(self.n, tf.float32)
q = (n - 1.0) * n / 2.0 - p - t - u - b
Expand All @@ -179,28 +188,11 @@ def get_config(self):

def reset_state(self):
"""Resets all of the metric state variables."""
self.actual_cuts = tf.linspace(
tf.cast(self.actual_min, tf.float32),
tf.cast(self.actual_max, tf.float32),
self.actual_cutpoints - 1,
)
self.preds_cuts = tf.linspace(
tf.cast(self.preds_min, tf.float32),
tf.cast(self.preds_max, tf.float32),
self.preds_cutpoints - 1,
)
self.m = tf.SparseTensor(
tf.zeros((0, 2), tf.int64),
[],
[self.actual_cutpoints, self.preds_cutpoints],
)
self.nrow = tf.SparseTensor(
tf.zeros((0, 1), dtype=tf.int64), [], [self.actual_cutpoints]
)
self.ncol = tf.SparseTensor(
tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints]
)
self.n = 0

self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64))
self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64))
self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64))
self.n.assign(0)

def reset_states(self):
# Backwards compatibility alias of `reset_state`. New classes should
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/metrics/tests/kendalls_tau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def test_keras_binary_classification_model():
x = np.random.rand(1000, 10).astype(np.float32)
y = np.random.rand(1000, 1).astype(np.float32)

model.fit(x, y, epochs=1, verbose=0, batch_size=32)
history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
assert not any(np.isnan(history.history["kendalls_tau"]))


def test_kendalls_tau_serialization():
Expand Down