BiRefNet: Bilateral Reference for High-Resolution Dichotomous Image Segmentation

Introduction

BiRefNet is an advanced model for high-resolution dichotomous image segmentation, excelling in tasks such as camouflaged and salient object detection. This repository contains the official implementation and weights of the BiRefNet model as described in the paper "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (CAAI AIR 2024).

Architecture

BiRefNet targets high-resolution image segmentation. Its bilateral reference mechanism is intended to improve segmentation accuracy, particularly for dichotomous (foreground versus background) segmentation. The model supports applications such as background removal and mask generation.

Training

BiRefNet was trained on the DIS-TR dataset and validated on DIS-TEs and DIS-VD datasets, achieving state-of-the-art performance in dichotomous image segmentation (DIS), high-resolution salient object detection (HRSOD), and camouflaged object detection (COD). Training relied on GPU resources generously provided by collaborators.

Guide: Running Locally

Basic Steps

  1. Install Packages
    Install the required packages using:

    pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
    
  2. Load BiRefNet Using Hugging Face
    Load the model with pre-trained weights:

    from transformers import AutoModelForImageSegmentation
    birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
    
  3. Load from GitHub
    Clone the repository for the latest code and use Hugging Face for weights:

    git clone https://github.com/ZhengPeng7/BiRefNet.git
    cd BiRefNet
    
    # Python: run from the repository root so that the models package is importable
    from models.birefnet import BiRefNet
    birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
    
  4. Local Weights and Code
    Use both code and weights locally:

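    import torch
    from models.birefnet import BiRefNet
    from utils import check_state_dict
    
    # PATH_TO_WEIGHT is a placeholder for the path to your local checkpoint file
    birefnet = BiRefNet(bb_pretrained=False)
    state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
    # check_state_dict (from the repository's utils) cleans up the checkpoint keys before loading
    state_dict = check_state_dict(state_dict)
    birefnet.load_state_dict(state_dict)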
    
  5. Inference
    Perform inference with BiRefNet:

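    import torch
    from PIL import Image
    import matplotlib.pyplot as plt
    from torchvision import transforms
    
    def extract_object(birefnet, imagepath):
        # Resize to the model input size and normalize with ImageNet statistics
        image_size = (1024, 1024)
        transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Convert to RGB so grayscale or RGBA files still match the 3-channel Normalize
        image = Image.open(imagepath).convert('RGB')
        input_images = transform_image(image).unsqueeze(0).to('cuda')
    
        # Take the last output and map it to [0, 1] with a sigmoid
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # Scale the mask back to the original resolution and use it as the alpha channel
        mask = pred_pil.resize(image.size)
        image.putalpha(mask)
        return image, mask
    
    # Move the model to the GPU and switch to evaluation mode before inference
    birefnet.to('cuda')
    birefnet.eval()
    
    plt.axis("off")
    plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
    plt.show()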
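
The soft mask returned by extract_object can be post-processed before use. The following is a minimal sketch, assuming Pillow is installed. The helper names binarize_mask and composite_on_background are hypothetical and not part of the BiRefNet repository, and the 0.5 threshold is only an illustrative choice.

```python
from PIL import Image


def binarize_mask(mask, threshold=0.5):
    """Turn a soft 8-bit mask (mode 'L') into a hard 0/255 mask."""
    cutoff = int(round(threshold * 255))
    # point() maps every pixel value through the given function
    return mask.convert("L").point(lambda p: 255 if p >= cutoff else 0)


def composite_on_background(image, mask, color=(255, 255, 255)):
    """Place the masked foreground on a solid background color."""
    background = Image.new("RGB", image.size, color)
    # Image.composite takes pixels from the first image where the mask is 255
    return Image.composite(image.convert("RGB"), background, mask.convert("L"))
```

With the values returned by extract_object, a call such as composite_on_background(image, binarize_mask(mask)) would give an RGB image with the background replaced by white. Keeping the soft mask instead of binarizing it preserves smoother edges.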
    

Cloud GPUs

For faster training and inference, especially on high-resolution images, consider running BiRefNet on GPU instances from a cloud provider such as AWS, GCP, or Azure.
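
Whether a GPU is actually available differs between a local machine and a cloud instance, so it can help to choose the device at runtime. The following is a minimal sketch. pick_device is a hypothetical helper, not part of the BiRefNet repository, and it falls back to the CPU when PyTorch or CUDA is unavailable.

```python
def pick_device():
    """Return 'cuda' when a CUDA GPU is usable, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so nothing can run on a GPU
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Running on: {device}")
```

The returned string can be passed to .to(device) for both the model and the input tensor in place of a hard-coded 'cuda'.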

License

BiRefNet is released under the MIT License, which permits use, modification, and redistribution provided the copyright and license notice are retained.
