Skip to content

Commit

Permalink
SAM e2e test tolerance explained
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Mar 22, 2024
1 parent 364e196 commit 2763db9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
41 changes: 39 additions & 2 deletions tests/foundationals/segment_anything/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

# See predictor_example.ipynb official notebook
Expand Down Expand Up @@ -409,7 +409,7 @@ def test_predictor_single_output(
assert torch.allclose(
low_res_masks[0, 0, ...],
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
atol=6e-3, # see test_predictor_resized_single_output for more explanation
)
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)

Expand All @@ -418,6 +418,43 @@ def test_predictor_single_output(
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)


def test_predictor_resized_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    # Why this test exists:
    # test_predictor_single_output needs an absolute tolerance of 6e-3 when
    # comparing refiners' SAM with the official implementation. That gap is
    # attributed to two components:
    #   * the image encoder (see test_image_encoder)
    #   * point rescaling (facebook uses numpy while refiners uses torch)
    #
    # This test bypasses both of them:
    #   * the image embedding computed by the facebook predictor is reused
    #   * the image is resized to (1024, 1024) beforehand, so no rescaling occurs
    # Under those conditions the low-res masks are expected to match exactly,
    # hence the strict torch.equal comparison below.

    target_size = (1024, 1024)
    square_truck = truck.resize(target_size)
    facebook_sam_h_predictor.set_image(np.array(square_truck))

    _, _, reference_low_res_masks = facebook_sam_h_predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )

    # Feed refiners' model the embedding produced by the official encoder.
    reference_embedding = ImageEmbedding(
        features=facebook_sam_h_predictor.features,
        original_image_size=target_size,
    )

    _, _, refiners_low_res_masks = sam_h_single_output.predict(reference_embedding, **vars(one_prompt))

    actual_mask = refiners_low_res_masks[0, 0, ...]
    expected_mask = torch.as_tensor(reference_low_res_masks[0], device=sam_h_single_output.device)
    assert torch.equal(actual_mask, expected_mask)


def test_mask_encoder(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/foundationals/segment_anything/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def device(self) -> Any: ...

class FacebookSAMPredictor:
model: FacebookSAM
features: Tensor

def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...

Expand Down

0 comments on commit 2763db9

Please sign in to comment.