Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565707394
  • Loading branch information
tensorflower-gardener committed Sep 15, 2023
1 parent b4294ca commit aeaf149
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions official/vision/ops/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,9 +1716,9 @@ def _apply_func_with_prob(func: Any, image: tf.Tensor,
return augmented_image, augmented_bboxes


def select_and_apply_random_policy(policies: Any,
image: tf.Tensor,
bboxes: Optional[tf.Tensor] = None):
def select_and_apply_random_policy(
policies: Any, image: tf.Tensor, bboxes: Optional[tf.Tensor] = None
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Select a random policy from `policies` and apply it to `image`."""
policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
# Note that using tf.case instead of tf.conds would result in significantly
Expand Down Expand Up @@ -2075,6 +2075,7 @@ def distort_with_boxes(self, image: tf.Tensor,
tf_policies = self._make_tf_policies()
image, bboxes = select_and_apply_random_policy(tf_policies, image, bboxes)
image = tf.cast(image, dtype=input_image_type)
assert bboxes is not None
return image, bboxes

@staticmethod
Expand Down Expand Up @@ -2493,6 +2494,7 @@ def distort_with_boxes(self, image: tf.Tensor,
bboxes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""See base class."""
image, bboxes = self._distort_common(image, bboxes)
assert bboxes is not None
return image, bboxes


Expand Down

0 comments on commit aeaf149

Please sign in to comment.