segformer b3 fashion
sayeed99
Introduction
The SEGFORMER-B3-FASHION model is a fine-tuned version of nvidia/mit-b3 designed for fashion image segmentation. It was fine-tuned on the sayeed99/fashion_segmentation dataset, using images at their original sizes rather than resized copies.
Architecture
The model is based on the SegFormer architecture, which pairs a hierarchical transformer encoder (here nvidia/mit-b3) with a lightweight decoder for semantic segmentation. It is implemented in PyTorch and uses the Transformers library for preprocessing and fine-tuning. The model outputs segmentation maps with 47 labels, ranging from clothing items like shirts and dresses to accessories such as glasses and hats.
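To make the output format concrete, the sketch below shows in plain Python how a stack of per-class score maps is reduced to a single label map by taking the highest-scoring class at each pixel. The helper name `argmax_label_map` and the tiny 3-class scores are illustrative assumptions, not part of the model; the real model emits 47 score maps, and the same reduction is normally done with a tensor `argmax`.

```python
def argmax_label_map(scores):
    """Reduce per-class score maps [C][H][W] to a label map [H][W].

    Hypothetical helper for illustration; ties go to the lowest class index.
    """
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    label_map = []
    for y in range(height):
        row = []
        for x in range(width):
            # Pick the class whose score is largest at this pixel.
            best = max(range(num_classes), key=lambda c: scores[c][y][x])
            row.append(best)
        label_map.append(row)
    return label_map


# Toy example: 3 classes on a 2x2 image (made-up numbers).
toy_scores = [
    [[0.1, 2.0], [0.3, 0.0]],  # class 0
    [[1.5, 0.2], [0.3, 0.1]],  # class 1
    [[0.0, 0.1], [0.2, 0.9]],  # class 2
]
toy_labels = argmax_label_map(toy_scores)
print(toy_labels)
```

Each entry of the resulting map is a class index, which is what a segmentation map stores per pixel.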
Training
The model was trained on the sayeed99/fashion_segmentation dataset. Training used the images at their original sizes, with the aim of preserving detail and improving segmentation accuracy. The training environment used the following library versions:
- Transformers Version: 4.30.0
- PyTorch Version: 2.2.2+cu121
- Datasets Version: 2.18.0
- Tokenizers Version: 0.13.3
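As a convenience, a local environment can be compared against the versions listed above. The following is a minimal sketch using only the standard library; the function name `find_version_mismatches` and the `EXPECTED_VERSIONS` dictionary are assumptions introduced here for illustration, and matching these versions exactly may not be required to run the model.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in the training setup above.
EXPECTED_VERSIONS = {
    "transformers": "4.30.0",
    "torch": "2.2.2+cu121",
    "datasets": "2.18.0",
    "tokenizers": "0.13.3",
}


def find_version_mismatches(expected):
    """Return {package: (expected, installed)} for packages that differ.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    found = find_version_mismatches(EXPECTED_VERSIONS)
    for name, (wanted, installed) in found.items():
        print(f"{name}: expected {wanted}, found {installed}")
```

An empty result means every listed package is installed at the listed version.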
Guide: Running Locally
To run the SEGFORMER-B3-FASHION model locally, follow these steps:
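- Install Required Libraries (the +cu121 PyTorch build is served from the PyTorch wheel index rather than PyPI, and the later steps also need Pillow, requests, and matplotlib):
    pip install transformers==4.30.0 torch==2.2.2+cu121 datasets==2.18.0 tokenizers==0.13.3 pillow requests matplotlib --extra-index-url https://download.pytorch.org/whl/cu121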
- Load the Model and Processor:
    from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
    from PIL import Image
    import requests
    import torch.nn as nn

    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
- Process an Image:
    url = "https://example.com/image.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    inputs = processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # PIL reports size as (width, height); interpolate expects (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    # Per-pixel class index for the first (only) image in the batch.
    pred_seg = upsampled_logits.argmax(dim=1)[0]
- Visualize the Segmentation:
    import matplotlib.pyplot as plt

    plt.imshow(pred_seg)
    plt.show()
To accelerate processing, consider using cloud GPUs such as those offered by AWS, Google Cloud, or Azure.
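Beyond plotting, it is often useful to see which classes a prediction actually contains. The sketch below counts pixels per class in a label map and attaches names; it is a plain-Python illustration under stated assumptions. The helper `summarize_labels`, the toy label map, and the toy names are made up here. With the real model, the mapping would presumably come from `model.config.id2label`, and a tensor such as `pred_seg` can be converted with `.tolist()`.

```python
from collections import Counter


def summarize_labels(label_map, id2label):
    """Count pixels per predicted class, most frequent first.

    `label_map` is a 2D list of integer class ids. Ids missing from
    `id2label` are reported under a generic "class_<id>" name.
    """
    counts = Counter(class_id for row in label_map for class_id in row)
    return [
        (id2label.get(class_id, f"class_{class_id}"), n)
        for class_id, n in counts.most_common()
    ]


# Toy data: a 2x3 label map and a made-up id-to-name mapping.
toy_map = [[0, 0, 4], [0, 4, 7]]
toy_names = {0: "background", 4: "dress", 7: "hat"}
summary = summarize_labels(toy_map, toy_names)
for name, pixel_count in summary:
    print(f"{name}: {pixel_count} px")
```

A summary like this makes it easy to spot predictions dominated by a single class before inspecting the plot.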
License
The model is distributed under the terms of the SegFormer License; review that license before using the model.