MEDCLIP: Fine-Tuning a CLIP Model on the ROCO Medical Dataset
Introduction
This repository provides the code for fine-tuning a CLIP model on the ROCO dataset, which consists of radiology images paired with captions. The project was part of the Flax/JAX community week organized by Hugging Face and Google.
Architecture
The model uses the FlaxHybridCLIP architecture, implemented in Flax/JAX and trained on a Cloud TPU v3-8. The fine-tuned checkpoint is available on the Hugging Face Hub as flax-community/medclip-roco.
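A hybrid CLIP model pairs a text encoder with a vision encoder behind CLIP-style projection heads. As a rough sketch only, and assuming the bundled modeling file mirrors the upstream Transformers hybrid-CLIP research example (including its from_text_vision_pretrained helper), such a model can be assembled from two separate checkpoints; the encoder names below are illustrative assumptions, not necessarily the ones used for this model.

```python
# Sketch only: assembling a hybrid CLIP model from separate text and vision
# encoders. The encoder checkpoints are illustrative assumptions, not
# necessarily the ones used to train medclip-roco.
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_text_vision_pretrained(
    "roberta-base",                   # assumed text encoder
    "openai/clip-vit-base-patch32",   # assumed vision encoder
)
print(model.config)
```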
Training
Training used the ROCO dataset, comprising 57,780 images for training, 7,200 for validation, and 7,650 for testing. Images with captions shorter than 10 characters were excluded. Training was run with the run_medclip.sh script, and the validation loss curve is available for review.
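For illustration, the caption-length filter described above could be applied as follows, assuming the ROCO captions are loaded as a Hugging Face datasets split with a "caption" column; the CSV path and column name are assumptions, not part of this repository.

```python
# Sketch of the caption-length filter: drop examples whose caption is
# shorter than 10 characters. File path and column name are assumptions.
from datasets import load_dataset

dataset = load_dataset("csv", data_files={"train": "roco_train.csv"})["train"]  # hypothetical file
dataset = dataset.filter(lambda example: len(example["caption"]) >= 10)
print(f"{len(dataset)} examples remain after filtering short captions")
```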
Guide: Running Locally
- Clone the Transformers Repository: Ensure you are on the master branch.
- Install Dependencies: Use a virtual environment and run pip install -e ".[flax]".
- Load the Model (a usage sketch follows this list):
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
- Utilize Cloud Accelerators: For efficient training, consider using cloud-based TPUs or GPUs.
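Once the checkpoint is loaded, a quick sanity check is to embed one image and one caption and compare them. The sketch below makes several assumptions: that a standard CLIP feature extractor is an adequate image preprocessor, that a tokenizer is bundled with the Hub checkpoint, that the model exposes the get_image_features/get_text_features helpers from the upstream hybrid-CLIP example, and that chest_xray.png is a local file. None of these are guaranteed by the repository.

```python
# Hedged usage sketch: embed an image and a caption with the fine-tuned model
# and compute their cosine similarity. Preprocessing choices are assumptions.
import jax.numpy as jnp
from PIL import Image
from transformers import AutoTokenizer, CLIPFeatureExtractor
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")  # assumed preprocessor
tokenizer = AutoTokenizer.from_pretrained("flax-community/medclip-roco")  # assumed to be bundled with the checkpoint

image = Image.open("chest_xray.png").convert("RGB")  # hypothetical local image
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
text_inputs = tokenizer(["Chest X-ray showing no acute disease."], return_tensors="np", padding=True)

image_embeds = model.get_image_features(pixel_values)
text_embeds = model.get_text_features(**text_inputs)

# Cosine similarity between L2-normalised embeddings
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
print(jnp.matmul(image_embeds, text_embeds.T))
```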
License
The content and code in this repository are subject to the terms and conditions of the LICENSE file included in the repository.