Conditional DETR ResNet-50

microsoft

Introduction

Conditional DETR is an object detection model with a ResNet-50 backbone, introduced to address the slow training convergence of DETR. It was trained on the COCO 2017 dataset (118k annotated images) and uses a conditional cross-attention mechanism to speed up training.

Architecture

The model uses a transformer encoder-decoder architecture with a conditional cross-attention mechanism in the decoder. Unlike standard DETR cross-attention, each decoder query is paired with a conditional spatial query derived from a reference point, which narrows the spatial region the query attends to. This reduces the decoder's dependence on high-quality content embeddings and leads to faster convergence.
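In the decoder's cross-attention, each query splits into a content part and a conditional spatial part built from a 2-D reference point, and the attention logit is the sum of a content term and a spatial term. A minimal NumPy sketch of this idea (the dimensions, values, and the diagonal transform `T` here are illustrative stand-ins; the actual model learns `T` with a small feed-forward network on the decoder embedding):

```python
import numpy as np

def sinusoidal_embed(point, dim=256, temperature=10000.0):
    # Map a normalized 2-D reference point (x, y) to a dim-d positional embedding,
    # half the channels for x and half for y, DETR-style sine/cosine encoding.
    half = dim // 2
    freqs = temperature ** (2 * np.arange(half // 2) / half)
    parts = []
    for coord in point:
        angles = coord / freqs
        parts.append(np.concatenate([np.sin(angles), np.cos(angles)]))
    return np.concatenate(parts)

rng = np.random.default_rng(0)
dim = 256
content_query = rng.standard_normal(dim)  # decoder content embedding (stand-in values)
reference_point = np.array([0.5, 0.3])    # normalized (x, y) reference point

# Conditional spatial query: element-wise scale the sinusoidal embedding of the
# reference point by a transform T of the content embedding. Here T is a simple
# stand-in; the paper learns it with a small FFN.
T = np.tanh(content_query)
spatial_query = T * sinusoidal_embed(reference_point, dim)

# Cross-attention logit for one encoder key: content term + spatial term,
# i.e. the dot product over concatenated [content; spatial] query/key pairs.
content_key = rng.standard_normal(dim)
spatial_key = sinusoidal_embed(np.array([0.52, 0.28]), dim)
logit = content_query @ content_key + spatial_query @ spatial_key
```

Because the spatial term depends on the reference point rather than on image content, it can localize the object's extremities early in training, which is the source of the faster convergence.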

Training

The model was trained on the COCO 2017 dataset, which includes 118k training images and 5k validation images. The conditional cross-attention mechanism allows the model to converge up to 10 times faster than the original DETR, depending on the backbone used.

Guide: Running Locally

To run the Conditional DETR model locally, follow these steps:

  1. Install Dependencies: Ensure you have transformers, torch, Pillow, and requests installed.
  2. Load Model and Processor: Use AutoImageProcessor and ConditionalDetrForObjectDetection from the transformers library.
  3. Prepare Input Image: Fetch and prepare the image using PIL and requests.
  4. Process and Predict: Use the processor to prepare inputs and the model to make predictions.

```python
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection
import torch
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# Convert raw outputs to boxes in (x_min, y_min, x_max, y_max) pixel coordinates,
# keeping only detections above the confidence threshold.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )
```

Cloud GPUs such as those available from AWS, Google Cloud, or Azure can be used to speed up computations.
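As a sketch of running on such a GPU, the same pipeline can be moved to CUDA when available, falling back to CPU otherwise (the model, image URL, and processing calls are the ones from the guide above; the device-handling pattern is a common convention, not specific to this model):

```python
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection

# Pick the GPU if one is present, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained(
    "microsoft/conditional-detr-resnet-50"
).to(device)
model.eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Move the pre-processed tensors to the same device as the model.
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]]).to(device)
results = processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.7
)[0]
```

Inference is a forward pass only, so `torch.no_grad()` avoids storing gradients and keeps GPU memory usage low.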

License

The model is released under the Apache-2.0 License, allowing free use with proper attribution.
