BiRefNet
ZhengPeng7/BiRefNet: Bilateral Reference for High-Resolution Dichotomous Image Segmentation
Introduction
BiRefNet is an advanced model for high-resolution dichotomous image segmentation, excelling in tasks such as camouflaged and salient object detection. This repository contains the official implementation and weights of the BiRefNet model as described in the paper "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (CAAI AIR 2024).
Architecture
BiRefNet targets high-resolution dichotomous image segmentation, where every pixel is labeled as either foreground or background. It uses a bilateral reference mechanism to improve segmentation accuracy, and it supports applications such as background removal and mask generation.
Training
BiRefNet was trained on the DIS-TR dataset and validated on DIS-TEs and DIS-VD datasets, achieving state-of-the-art performance in dichotomous image segmentation (DIS), high-resolution salient object detection (HRSOD), and camouflaged object detection (COD). Training relied on GPU resources generously provided by collaborators.
Guide: Running Locally
Basic Steps
- Install Packages
  Install the required packages using:

      pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
- Load BiRefNet Using Hugging Face
  Load the model with pre-trained weights:

      from transformers import AutoModelForImageSegmentation

      birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
- Load from GitHub
  Clone the repository for the latest code and use Hugging Face for weights:

      git clone https://github.com/ZhengPeng7/BiRefNet.git
      cd BiRefNet

  Then load the model in Python from the repository root:

      from models.birefnet import BiRefNet

      birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
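- Local Weights and Code
  Use both code and weights locally (run from the cloned repository root):

      import torch
      from models.birefnet import BiRefNet
      from utils import check_state_dict

      # Build the model without downloading pre-trained backbone weights
      birefnet = BiRefNet(bb_pretrained=False)

      # PATH_TO_WEIGHT is a placeholder: set it to your local checkpoint file
      state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
      state_dict = check_state_dict(state_dict)
      birefnet.load_state_dict(state_dict)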
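- Inference
  Perform inference with BiRefNet:

      import torch
      from PIL import Image
      import matplotlib.pyplot as plt
      from torchvision import transforms

      # The inputs below are moved to 'cuda', so the model must be on the same device
      birefnet.to('cuda')
      birefnet.eval()

      def extract_object(birefnet, imagepath):
          # Resize to the model input size and apply ImageNet normalization
          image_size = (1024, 1024)
          transform_image = transforms.Compose([
              transforms.Resize(image_size),
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
          ])

          # Convert to RGB so grayscale or RGBA files match the 3-channel normalization
          image = Image.open(imagepath).convert('RGB')
          input_images = transform_image(image).unsqueeze(0).to('cuda')

          # Take the last output and apply sigmoid to get mask values in [0, 1]
          with torch.no_grad():
              preds = birefnet(input_images)[-1].sigmoid().cpu()
          pred = preds[0].squeeze()

          # Resize the mask to the original size and use it as the alpha channel
          pred_pil = transforms.ToPILImage()(pred)
          mask = pred_pil.resize(image.size)
          image.putalpha(mask)
          return image, mask

      # Visualize the extracted object
      plt.axis("off")
      plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
      plt.show()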
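The mask returned by `extract_object` comes from a sigmoid, so it acts as a soft alpha matte rather than a hard cutout. The following is a minimal sketch of how such a matte blends the extracted object over a new solid background, written in plain Python so that the arithmetic is visible. The names `composite_pixel` and `composite_image` are hypothetical helpers for illustration and are not part of BiRefNet; the sketch assumes mask values have already been scaled to [0, 1].

```python
def composite_pixel(fg, bg, alpha):
    """Blend one RGB foreground pixel over one RGB background pixel.

    alpha is the mask value in [0, 1]: 1 keeps the object, 0 keeps the background.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return tuple(round(alpha * f + (1.0 - alpha) * b) for f, b in zip(fg, bg))


def composite_image(fg_rows, bg_color, mask_rows):
    """Composite a row-major RGB image over a solid color using a soft mask."""
    return [
        [composite_pixel(px, bg_color, a) for px, a in zip(fg_row, mask_row)]
        for fg_row, mask_row in zip(fg_rows, mask_rows)
    ]


# Tiny 1x3 example: a red object over a blue background
red, blue = (255, 0, 0), (0, 0, 255)
result = composite_image([[red, red, red]], blue, [[1.0, 0.25, 0.0]])
print(result)
```

For real images a per-pixel Python loop is too slow; Pillow's `Image.composite(image, background, mask)` performs the same kind of blend on whole images and is likely the more practical choice.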
Cloud GPUs
For faster training and inference, especially on high-resolution images, consider running BiRefNet on cloud GPU instances from providers such as AWS, GCP, or Azure.
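The inference snippet in this guide hard-codes `'cuda'`, which fails on machines without an NVIDIA GPU. One possible way to make the same code run on both a cloud GPU instance and a CPU-only laptop is to choose the device at runtime. This is a sketch under the assumption that falling back to the CPU is acceptable; `pick_device` is a hypothetical helper, not part of BiRefNet.

```python
def pick_device():
    """Return 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.

    `pick_device` is a hypothetical helper for illustration only.
    """
    try:
        import torch  # imported lazily so the helper also works without PyTorch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Running on: {device}")
```

With this in place, `birefnet.to(device)` and `input_images.to(device)` could replace the hard-coded `'cuda'` strings. Expect CPU inference at 1024x1024 to be considerably slower than on a GPU.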
License
BiRefNet is licensed under the MIT License, which permits use, modification, and redistribution provided that the copyright and license notice are retained.