PubMedBERT Base Embeddings 8M

NeuML

Introduction

The PubMedBERT Base Embeddings 8M model is a distilled version of PubMedBERT that uses the Model2Vec library to produce static embeddings. It targets environments with limited computational resources or where real-time performance is essential, and it is particularly useful for semantic search and retrieval-augmented generation (RAG).

Architecture

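This model uses static embeddings. Every token maps to a fixed, precomputed vector, and a text's embedding is the average of its token vectors, so no transformer forward pass is needed at inference time. As a result, text embeddings are significantly faster to compute on both GPU and CPU. The trade-off is some accuracy relative to the full transformer model, although accuracy remains competitive.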
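
The core inference step can be illustrated with a minimal sketch. It assumes a hypothetical toy vocabulary (`VOCAB`) and a three-dimensional embedding table (`TABLE`), and it uses whitespace tokenization in place of the model's real subword tokenizer. The point is that encoding reduces to one table lookup per token followed by a mean, with no attention layers. Model2Vec's actual implementation differs in detail.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding table; the real model uses a
# PubMedBERT-derived vocabulary and learned vectors.
VOCAB = {"aspirin": 0, "reduces": 1, "fever": 2, "[UNK]": 3}
TABLE = np.array(
    [
        [0.9, 0.1, 0.0],  # aspirin
        [0.1, 0.8, 0.1],  # reduces
        [0.0, 0.2, 0.9],  # fever
        [0.0, 0.0, 0.0],  # unknown tokens map to a zero vector in this sketch
    ],
    dtype=np.float32,
)


def embed(text: str) -> np.ndarray:
    """Embed text by looking up one fixed vector per token and averaging them."""
    # Whitespace tokenization stands in for the model's subword tokenizer.
    ids = [VOCAB.get(tok, VOCAB["[UNK]"]) for tok in text.lower().split()]
    # No transformer layers: just a table lookup and a mean.
    vector = TABLE[ids].mean(axis=0)
    # Unit-normalize so that dot products equal cosine similarity.
    norm = np.linalg.norm(vector)
    return vector / norm if norm > 0 else vector


print(embed("aspirin reduces fever"))
```

Because the result is a plain average, word order does not change the embedding. This is part of the accuracy cost that buys the speed.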

Training

The model was trained with the Tokenlearn library in two stages: a featurization script first generated training features from the data, and the model was then trained on those features. BM25 weighting was applied to improve accuracy. The model uses PCA for dimensionality reduction, and embeddings are normalized after training.
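
The post-processing ideas named above can be sketched with NumPy. This is not Tokenlearn's code: the BM25-style IDF formula (the Lucene variant), the order of PCA before weighting, and the helper names `bm25_idf`, `post_process`, and `encode` are all assumptions for illustration. The sketch reduces a token table with PCA via SVD, scales each token vector by an IDF weight so that rare tokens count more in the average, and normalizes the pooled output.

```python
import numpy as np


def bm25_idf(doc_freq: np.ndarray, num_docs: int) -> np.ndarray:
    """BM25-style IDF (Lucene variant, assumed here): rarer tokens get larger weights."""
    return np.log((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0)


def post_process(token_vectors: np.ndarray, doc_freq: np.ndarray,
                 num_docs: int, dims: int) -> np.ndarray:
    """Reduce a token embedding table with PCA, then apply BM25-style weights."""
    # PCA via SVD: center the table, then project onto the top `dims` axes.
    centered = token_vectors - token_vectors.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    reduced = centered @ vt[:dims].T
    # Scale each token's row so it counts more (or less) when averaged.
    return reduced * bm25_idf(doc_freq, num_docs)[:, None]


def encode(token_ids: list[int], table: np.ndarray) -> np.ndarray:
    """Mean-pool the selected token vectors and L2-normalize the result."""
    vector = table[token_ids].mean(axis=0)
    norm = np.linalg.norm(vector)
    return vector / norm if norm > 0 else vector


# Hypothetical data: 6 tokens with 8-dimensional vectors and document
# frequencies counted over a 100-document corpus.
rng = np.random.default_rng(0)
vectors = rng.normal(size=(6, 8))
doc_freq = np.array([90, 50, 10, 5, 2, 1])
table = post_process(vectors, doc_freq, num_docs=100, dims=4)
print(table.shape)  # (6, 4)
print(encode([0, 4], table))
```

Weighting is applied to the token table rather than to the final output, so it still influences the mean; normalizing each token row afterwards would cancel it out.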

Guide: Running Locally

  1. Setup Environment: Ensure you have Python installed along with necessary libraries such as txtai, sentence-transformers, and model2vec.

  2. Using txtai:

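    import txtai

    # Any iterable of text works; this small list stands in for your own corpus
    data = ["Aspirin is used to reduce fever", "Insulin regulates blood glucose"]

    # content=True stores the original text so search results include it
    embeddings = txtai.Embeddings(path="neuml/pubmedbert-base-embeddings-8M", content=True)
    embeddings.index(data)
    print(embeddings.search("query to run", 1))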
    
  3. Using Sentence-Transformers:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import StaticEmbedding
    
    # Load the Model2Vec weights as a Sentence-Transformers StaticEmbedding module
    static = StaticEmbedding.from_model2vec("neuml/pubmedbert-base-embeddings-8M")
    model = SentenceTransformer(modules=[static])
    
    sentences = ["This is an example sentence", "Each sentence is converted"]
    embeddings = model.encode(sentences)
    print(embeddings)
    
  4. Using Model2Vec:

    from model2vec import StaticModel
    
    # Download (or load from cache) the static embedding model
    model = StaticModel.from_pretrained("neuml/pubmedbert-base-embeddings-8M")
    sentences = ["This is an example sentence", "Each sentence is converted"]
    embeddings = model.encode(sentences)
    print(embeddings)
    
  5. Cloud GPUs: For large indexing jobs, a cloud GPU instance (for example, an NVIDIA RTX 3090) can speed up indexing and querying, although static embeddings already run quickly on CPU.
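
Whichever option you use, the output is a matrix of vectors, and a brute-force semantic search over it is a cosine-similarity ranking. The sketch below assumes a hypothetical helper named `search` written with NumPy; txtai's own `search` method handles this for you.

```python
import numpy as np


def search(query_vec, doc_vecs, k=3):
    """Return (index, score) pairs for the k documents most similar to the query."""
    q = np.asarray(query_vec, dtype=np.float32)
    d = np.asarray(doc_vecs, dtype=np.float32)
    # Cosine similarity is the dot product of unit-normalized vectors.
    q = q / np.linalg.norm(q)
    d = d / np.linalg.norm(d, axis=1, keepdims=True)
    scores = d @ q
    # Highest scores first.
    top = np.argsort(-scores)[:k]
    return [(int(i), float(scores[i])) for i in top]
```

With Model2Vec, for example, `docs = model.encode(sentences)` and `query = model.encode(["query to run"])[0]` produce inputs for `search(query, docs)`. For large corpora, an approximate nearest-neighbor index scales better than this exhaustive scan.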

License

This model is licensed under the Apache-2.0 License, allowing for broad usage and modification with proper attribution.
