Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

birefnet background removal - add batch (directory) processing (#2489) #1

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 81 additions & 5 deletions extensions-builtin/forge_space_birefnet/forge_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import spaces
import os
import gradio as gr
Expand All @@ -7,9 +8,10 @@
import torch
from torchvision import transforms

# torch.set_float32_matmul_precision(["high", "highest"][0])
import glob
import pathlib
from PIL import Image

# Point HOME at the "home" path directly under whatever spaces.convert_root_path() returns.
os.environ['HOME'] = spaces.convert_root_path() + 'home'

with spaces.capture_gpu_object() as birefnet_gpu_obj:
birefnet = AutoModelForImageSegmentation.from_pretrained(
Expand Down Expand Up @@ -44,11 +46,69 @@ def fn(image):
image.putalpha(mask)
return (image, origin)

@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def batch_process(input_folder, output_folder, save_png, save_flat):
    """Remove the background from every supported image in a folder.

    Args:
        input_folder: Directory scanned (non-recursively) for images whose
            extension is one of .jpg, .jpeg, .png, .bmp or .webp, compared
            case-insensitively. A blank or non-existent folder yields no
            images.
        output_folder: Directory receiving the results. It is created when
            missing; files with the same name are overwritten.
        save_png: When true, any output name that does not already end in
            ``.png`` gets ``.png`` appended.
        save_flat: When true, the cut-out is composited onto a white
            background and saved as RGB. Otherwise the predicted mask is
            kept as the alpha channel.

    Returns:
        The paths of the files that were written, in sorted input order.
        Images that fail are reported with ``print`` and skipped.
    """
    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Supported image extensions (lower-case; matched case-insensitively)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

    # A blank path would otherwise silently mean "current working directory".
    if not input_folder or not os.path.isdir(input_folder):
        return []

    # Collect all image files from input folder. Listing the directory rather
    # than globbing keeps folder names containing glob metacharacters working
    # and lets upper-case extensions (e.g. ".JPG") match on case-sensitive
    # filesystems. Dot-files are skipped, as glob's "*" did.
    with os.scandir(input_folder) as entries:
        input_images = sorted(
            entry.path
            for entry in entries
            if entry.is_file()
            and not entry.name.startswith('.')
            and pathlib.Path(entry.name).suffix.lower() in image_extensions
        )

    # Loop-invariant converter from tensor to PIL image
    to_pil = transforms.ToPILImage()

    # Process each image
    processed_images = []
    for image_path in input_images:
        try:
            # Load image
            im = load_img(image_path, output_type="pil")
            im = im.convert("RGB")
            image_size = im.size
            image = load_img(im)

            # Prepare image for processing
            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu)

            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()

            pred = preds[0].squeeze()
            mask = to_pil(pred).resize(image_size)

            # Apply mask
            image.putalpha(mask)

            # Save processed image under the input's file name
            output_filename = os.path.join(output_folder, pathlib.Path(image_path).name)

            if save_flat:
                background = Image.new('RGBA', image.size, (255, 255, 255))
                image = Image.alpha_composite(background, image)
                image = image.convert("RGB")
            elif output_filename.lower().endswith((".jpg", ".jpeg")):
                # jpegs don't support alpha channel, so add .png extension
                # (not change, to avoid potential overwrites)
                output_filename += ".png"
            if save_png and not output_filename.lower().endswith(".png"):
                output_filename += ".png"

            image.save(output_filename)

            processed_images.append(output_filename)

        except Exception as e:
            # Best effort: one unreadable image must not abort the whole batch.
            print(f"Error processing {image_path}: {str(e)}")

    return processed_images

# Output sliders, one per single-image tab.
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
# Inputs: an uploaded image for one tab, a single-line text box for the other.
# `text` is created exactly once; the earlier "Paste an image URL" text box was
# overwritten on the very next line and never used.
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)


# Example input: chameleon.jpg under the spaces root path, loaded as a PIL image.
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
Expand All @@ -58,11 +118,27 @@ def fn(image):
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
)

# Text tab: runs `fn` on the contents of the `text` box and shows the result in
# `slider2`. Built once only; constructing a throw-away Interface from the same
# components first served no purpose.
tab2 = gr.Interface(
    fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never"
)

# Batch tab: two folder paths plus two save options feed `batch_process`;
# the files it returns are offered as a multi-file download.
_batch_inputs = [
    gr.Textbox(label="Input folder path", max_lines=1),
    gr.Textbox(label="Output folder path (will overwrite)", max_lines=1),
    gr.Checkbox(label="Always save as PNG", value=True),
    gr.Checkbox(label="Save flat (no mask)", value=False),
]
_batch_outputs = gr.File(label="Processed images", type="filepath", file_count="multiple")

tab3 = gr.Interface(
    fn=batch_process,
    inputs=_batch_inputs,
    outputs=_batch_outputs,
    api_name="batch",
    allow_flagging="never",
)

# Top-level app: one tab per interface, with the tab titles listed in the same order.
demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["image", "URL", "batch"],
    title="birefnet for background removal"
)

if __name__ == "__main__":
Expand Down