From 6ab9251f261fccfb529163925810b18b5c6f9135 Mon Sep 17 00:00:00 2001 From: srihari-humbarwadi Date: Thu, 31 Mar 2022 19:42:25 +0530 Subject: [PATCH] compute `top_k` loss per sample --- .../losses/panoptic_deeplab_losses.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py b/official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py index 97898610f7d..0abd561531d 100644 --- a/official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py +++ b/official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py @@ -76,18 +76,38 @@ def __call__(self, logits, labels, sample_weight=None): if self._top_k_percent_pixels >= 1.0: loss = tf.reduce_sum(cross_entropy_loss) / normalizer else: - cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1]) - top_k_pixels = tf.cast( - self._top_k_percent_pixels * - tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32) - top_k_losses, _ = tf.math.top_k( - cross_entropy_loss, k=top_k_pixels, sorted=True) - normalizer = tf.reduce_sum( - tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON - loss = tf.reduce_sum(top_k_losses) / normalizer - + loss = self._compute_top_k_loss(cross_entropy_loss) return loss + def _compute_top_k_loss(self, loss): + batch_size = tf.shape(loss)[0] + loss = tf.reshape(loss, shape=[batch_size, -1]) + + top_k_pixels = tf.cast( + self._top_k_percent_pixels * + tf.cast(tf.shape(loss)[-1], dtype=tf.float32), + dtype=tf.int32) + + # shape: [batch_size, top_k_pixels] + per_sample_top_k_loss = tf.map_fn( + fn=lambda x: tf.nn.top_k(x, k=top_k_pixels, sorted=False)[0], + elems=loss, + parallel_iterations=32, + fn_output_signature=tf.float32) + + # shape: [batch_size] + per_sample_normalizer = tf.reduce_sum( + tf.cast( + tf.not_equal(per_sample_top_k_loss, 0.0), + dtype=tf.float32), + axis=-1) + EPSILON + per_sample_normalized_loss = tf.reduce_sum( + per_sample_top_k_loss, axis=-1) / per_sample_normalizer + + normalized_loss = tf_utils.safe_mean(per_sample_normalized_loss) + return normalized_loss + + class CenterHeatmapLoss: def __init__(self): self._loss_fn = tf.losses.mean_squared_error