MEDCLIP: Fine-Tuning a CLIP Model on the ROCO Medical Dataset

Introduction

This repository provides the code for fine-tuning a CLIP model on the ROCO (Radiology Objects in COntext) dataset, which consists of radiology images paired with captions. This project was part of the Flax/JAX Community Week organized by Hugging Face and Google.

Architecture

The model uses the FlaxHybridCLIP architecture and was trained with the Flax/JAX frameworks on a cloud TPU v3-8. The pretrained model is available on the Hugging Face Hub.
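
FlaxHybridCLIP pairs a pretrained text encoder with a pretrained vision encoder behind a CLIP-style contrastive head. The sketch below shows how such a hybrid model can be assembled, assuming the from_text_vision_pretrained helper from the hybrid CLIP research project; the specific encoder checkpoints are illustrative assumptions, not necessarily the ones used for this model.

    from medclip.modeling_hybrid_clip import FlaxHybridCLIP

    # Assemble a CLIP-style model from independent text and vision backbones.
    # The checkpoint names below are assumptions chosen for illustration only.
    model = FlaxHybridCLIP.from_text_vision_pretrained(
        "allenai/scibert_scivocab_uncased",  # text encoder (assumed)
        "openai/clip-vit-base-patch32",      # vision encoder (assumed)
    )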

Training

Training used the ROCO dataset, comprising 57,780 images for training, 7,200 for validation, and 7,650 for testing. Image-caption pairs whose captions were shorter than 10 characters were excluded. Training was run with the run_medclip.sh script, and the validation loss curve is available for review.
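
The caption-length filter mentioned above is simple to reproduce. A minimal sketch follows; the dictionary layout (image_path, caption) is an assumption for illustration, not the actual ROCO metadata schema.

    # Drop image-caption pairs whose caption is shorter than 10 characters.
    MIN_CAPTION_LEN = 10

    raw_examples = [
        {"image_path": "roco/img_0001.jpg", "caption": "Chest X-ray showing consolidation."},
        {"image_path": "roco/img_0002.jpg", "caption": "CT scan"},  # too short, dropped
    ]

    examples = [ex for ex in raw_examples
                if len(ex["caption"].strip()) >= MIN_CAPTION_LEN]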

Guide: Running Locally

  1. Clone the Transformers Repository: Ensure you are on the master branch.
  2. Install Dependencies: Use a virtual environment and run pip install -e ".[flax]".
  3. Load the Model (a usage sketch follows this list):
    from medclip.modeling_hybrid_clip import FlaxHybridCLIP
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    
  4. Use Cloud Accelerators: For efficient training, consider cloud-based TPUs or GPUs.
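
After loading the model, image-text similarity can be computed in the shared embedding space. The following is a minimal inference sketch, assuming the get_text_features / get_image_features helpers of the hybrid CLIP class, a tokenizer hosted alongside the checkpoint, and a channel-last image layout; the preprocessing here is deliberately naive and should be replaced with the normalization expected by the vision backbone.

    import jax.numpy as jnp
    import numpy as np
    from PIL import Image
    from transformers import AutoTokenizer

    from medclip.modeling_hybrid_clip import FlaxHybridCLIP

    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    # Tokenizer checkpoint is an assumption; use the tokenizer of the text encoder.
    tokenizer = AutoTokenizer.from_pretrained("flax-community/medclip-roco")

    # Naive preprocessing: resize and scale to [0, 1]. The channel-last layout
    # (batch, height, width, channels) and the lack of normalization are
    # assumptions; match the vision encoder's expected preprocessing in practice.
    image = Image.open("example_radiograph.jpg").convert("RGB").resize((224, 224))
    pixel_values = jnp.asarray(np.array(image), dtype=jnp.float32)[None, ...] / 255.0

    captions = ["chest x-ray", "brain mri"]
    inputs = tokenizer(captions, padding=True, return_tensors="np")

    # Project both modalities into the shared embedding space.
    text_embeds = model.get_text_features(inputs["input_ids"],
                                          attention_mask=inputs["attention_mask"])
    image_embeds = model.get_image_features(pixel_values)

    # Cosine similarity between the image and each caption.
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    print(image_embeds @ text_embeds.T)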

License

The content and code in this repository are subject to the terms and conditions outlined in the LICENSE file provided within the repository.
