SegFormer B3 Fashion

sayeed99

Introduction

The SEGFORMER-B3-FASHION model is a fine-tuned version of nvidia/mit-b3 for fashion image segmentation. It was fine-tuned on the sayeed99/fashion_segmentation dataset at the images' original sizes, with no resizing, to segment fashion-related images.

Architecture

The model is based on the SegFormer architecture, which pairs a hierarchical Transformer encoder with a lightweight all-MLP decode head for efficient semantic segmentation. It is implemented in PyTorch and uses the Transformers library for preprocessing and fine-tuning. The model outputs segmentation maps with 47 labels, ranging from clothing items such as shirts and dresses to accessories such as glasses and hats.
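
To see exactly which 47 labels the checkpoint predicts, the label map can be read from the model configuration once the checkpoint is loaded. The following is a minimal sketch using the standard Transformers API:

    from transformers import AutoModelForSemanticSegmentation

    # Load the fine-tuned checkpoint and print its label map (47 entries expected)
    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
    for idx, name in model.config.id2label.items():
        print(idx, name)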

Training

The model was trained on the sayeed99/fashion_segmentation dataset. The training process used the original image sizes, without resizing, to maintain data integrity and improve segmentation accuracy; a sketch of matching this behavior at inference time follows the version list below. The training setup includes:

  • Transformers Version: 4.30.0
  • PyTorch Version: 2.2.2+cu121
  • Datasets Version: 2.18.0
  • Tokenizers Version: 0.13.3
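
If you want inference to mirror this original-size setup, the image processor can be told not to resize its inputs, as sketched below. This is illustrative only: the do_resize override and the local file name fashion_photo.jpg are assumptions, so check the checkpoint's preprocessor_config.json for the defaults it actually ships with.

    from transformers import SegformerImageProcessor
    from PIL import Image

    # Disable resizing so the image is processed at its native resolution
    # (assumption: this mirrors the original-size training setup described above)
    processor = SegformerImageProcessor.from_pretrained(
        "sayeed99/segformer-b3-fashion",
        do_resize=False,
    )
    image = Image.open("fashion_photo.jpg")  # hypothetical local file
    inputs = processor(images=image, return_tensors="pt")
    print(inputs["pixel_values"].shape)  # (1, 3, H, W) at the image's original size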

Guide: Running Locally

To run the SEGFORMER-B3-FASHION model locally, follow these steps:

  1. Install Required Libraries:

    # The +cu121 build of torch is hosted on the PyTorch package index, not PyPI
    pip install torch==2.2.2+cu121 --index-url https://download.pytorch.org/whl/cu121
    pip install transformers==4.30.0 datasets==2.18.0 tokenizers==0.13.3
    
  2. Load the Model and Processor:

    from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
    from PIL import Image
    import requests
    import torch.nn as nn

    # Download the image processor and the fine-tuned weights from the Hugging Face Hub
    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
    
  3. Process an Image:

    url = "https://example.com/image.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(images=image, return_tensors="pt")

    # Forward pass; logits have shape (batch, num_labels, height/4, width/4)
    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits back to the original resolution
    # (PIL's image.size is (width, height), so reverse it to get (height, width))
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel class index for the single image in the batch
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    
  4. Visualize the Segmentation:

    import matplotlib.pyplot as plt

    # Display the per-pixel class indices as an image
    plt.imshow(pred_seg)
    plt.show()
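
Because plt.imshow(pred_seg) renders raw class indices, it is often clearer to overlay a single class on the photo. The sketch below continues from the snippet above and is illustrative only; the class name "shirt, blouse" is an assumption and should be checked against model.config.id2label.

    # Invert the label map to look up a class index by name
    # (assumption: "shirt, blouse" exists in this model's label set)
    label2id = {name: idx for idx, name in model.config.id2label.items()}
    target_id = label2id.get("shirt, blouse")

    if target_id is not None:
        mask = (pred_seg == target_id)  # boolean mask for the chosen class
        plt.imshow(image)               # original photo
        plt.imshow(mask, alpha=0.5)     # translucent class overlay
        plt.axis("off")
        plt.show()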
    

To accelerate processing, consider using cloud GPUs such as those offered by AWS, Google Cloud, or Azure.
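
On a machine with a CUDA device, the same pipeline can also be moved onto the GPU with a few extra lines. This is a minimal sketch; device handling is not part of the snippet above.

    import torch

    # Pick a GPU when available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    with torch.no_grad():  # no gradients are needed for inference
        inputs = processor(images=image, return_tensors="pt").to(device)
        outputs = model(**inputs)
    logits = outputs.logits.cpu()  # move back to the CPU for post-processing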

License

The model is released under the SegFormer license; the full terms can be reviewed in the LICENSE file of the NVlabs/SegFormer repository.
