Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF SAM shape flexibility fixes #23842

Merged
merged 1 commit into from
May 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/transformers/models/sam/modeling_tf_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> t
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape(
hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head)
hidden_states,
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
)

def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
Expand Down Expand Up @@ -509,7 +510,7 @@ def call(
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
if sparse_prompt_embeddings.shape[1] != 0:
if shape_list(sparse_prompt_embeddings)[1] != 0:
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
else:
tokens = output_tokens
Expand Down Expand Up @@ -695,8 +696,8 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1)
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
Expand All @@ -722,12 +723,12 @@ def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.T
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = boxes.shape[:2]
batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(corner_embedding.shape[2])[None, None, :, None] == 0,
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
Expand All @@ -754,7 +755,7 @@ def call(
"""
sparse_embeddings = None
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
batch_size, point_batch_size = shape_list(input_points)[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
Expand All @@ -763,7 +764,7 @@ def call(
)
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None:
batch_size = input_boxes.shape[0]
batch_size = shape_list(input_boxes)[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
Expand Down Expand Up @@ -1376,8 +1377,8 @@ def call(
" got {}.".format(input_boxes.shape),
)
if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1]
box_batch_size = input_boxes.shape[1]
point_batch_size = shape_list(input_points)[1]
box_batch_size = shape_list(input_boxes)[1]
if point_batch_size != box_batch_size:
raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
Expand Down