
Using SAM2 detects few trees #314

Open
dnromero opened this issue Sep 15, 2024 · 2 comments
@dnromero

Hi everyone

I recently started evaluating SAM2 for tree detection.
I'd like to clarify that I am new to this whole topic and I'm trying to learn how to use SAM2 to detect trees.

I have tried the following code:
import os

# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. pytorch/pytorch#84936 for a discussion."
    )

np.random.seed(3)

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)

image = Image.open(r"C:\Users\Lenovo\Desktop\Daniel\AIconteo\cerro2_corte.tif")
image = np.array(image.convert("RGB"))
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = r"C:\Users\Lenovo\segment-anything-2\checkpoints\sam2_hiera_base_plus.pt"
model_cfg = "sam2_hiera_b+.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(sam2)
masks = mask_generator.generate(image)
print(len(masks))
print(masks[0].keys())
This prints:

121
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

Then I plot the masks:
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
(Attached images: example 1, example 2)

I've added some example images above. I'd be very grateful if anyone can help me. :)

@heyoeyo

heyoeyo commented Sep 16, 2024

For this sort of example, SAM might not be a great option. With so many small, similar-looking objects, it might be worth trying simpler methods like correlation (see template matching in scikit-image); even plain thresholding might be a good start (for this specific image at least).
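
For context, here's a rough sketch of the template-matching idea with scikit-image (the single_tree.png template crop and the threshold values are placeholder assumptions to illustrate the approach, not values from this issue):

# Hedged sketch: correlate a hand-cropped example tree against the full scene.
# File names and thresholds below are placeholders to illustrate the idea.
from skimage import io, color
from skimage.feature import match_template, peak_local_max

scene = color.rgb2gray(io.imread("cerro2_corte.tif"))      # full scene as grayscale
template = color.rgb2gray(io.imread("single_tree.png"))    # small crop of one tree (assumed to exist)

# Correlation map: bright peaks where the template matches well
result = match_template(scene, template, pad_input=True)

# Treat well-separated local maxima above a threshold as candidate tree detections
peaks = peak_local_max(result, min_distance=10, threshold_abs=0.5)
print(f"found {len(peaks)} candidate trees")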

However, you can probably improve the SAM result by adjusting the default settings. Two that seem like they might help are points_per_side and crop_n_layers. The auto-masking works by generating a bunch of single point prompts in a grid, and the points_per_side setting controls how many points are used in that grid. The default setting (32) is probably too low to account for so many objects, so it's worth increasing that (at least 64, maybe even higher?).
The crop_n_layers setting acts almost like a 'zoom-in' feature: it runs segmentation on smaller cropped parts of the image, which should help catch some of the smaller trees. This setting slows things down a lot, so see if a setting of 1 helps before increasing it any further.
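
As a rough sketch of how that might look with the code from the original post (the values 64 and 1 are just starting points to experiment with, not tuned recommendations):

# Same pipeline as before, but with a denser point grid and one crop layer
mask_generator = SAM2AutomaticMaskGenerator(
    sam2,
    points_per_side=64,  # default is 32; more points = more chances to land on small trees
    crop_n_layers=1,     # also segment cropped sub-images ("zoom in"); slower but catches smaller objects
)
masks = mask_generator.generate(image)
print(len(masks))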

If you want to see what these options are doing, you can add the following code just after line 266 of the mask generator:

# Visualize point prompts used by the mask generator
import cv2
debug_img = cv2.cvtColor(cropped_im, cv2.COLOR_RGB2BGR)
for xy in points_for_image:
    pt_xy = xy.astype(np.int32).tolist()
    cv2.circle(debug_img, pt_xy, 2, (255,0,255), -1)
cv2.imshow("DebugPoints", debug_img)
cv2.waitKey(0)
cv2.destroyWindow("DebugPoints")

This will create a pop-up image showing the 'crop' that's being used along with the point prompts (it will also pause the masking, but you can press any key with the window open to resume).

One last thing that's probably helpful if you only want the trees is to ignore any overly large masks. You can filter them out using something like:

# Filter out large masks
max_area = 1000
masks = [m for m in masks if m["area"] < max_area]

@HendricksJudy

Small object detection is really its own CV domain.
You may need to find a better embedding before you use SAM2.
