facebook/rag-token-nq

Introduction

The RAG-Token Model, introduced in the paper "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks" by Patrick Lewis et al., is designed to answer knowledge-intensive questions. It is uncased: all input text is lowercased. The model combines a question encoder, a retriever, and a generator; the question encoder and generator are fine-tuned on the Natural Questions dataset, while the retriever draws passages from the wiki_dpr index.

Architecture

The RAG-Token Model includes:

  • Question Encoder: Based on facebook/dpr-question_encoder-single-nq-base.
  • Retriever: Extracts relevant passages from the wiki_dpr index.
  • Generator: Built on facebook/bart-large, responsible for answer generation.

The question encoder and generator are fine-tuned jointly end-to-end on the question-answering dataset; the passage index itself stays fixed during fine-tuning.
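At each generation step, RAG-Token marginalizes the generator's next-token distribution over the retrieved passages, weighting each passage's distribution by its retrieval probability. A minimal numpy sketch of that step (toy scores and a hypothetical 4-word vocabulary, not the real model's tensors):

```python
import numpy as np

def rag_token_step(doc_scores, per_doc_token_probs):
    """Marginalize next-token probabilities over retrieved passages.

    doc_scores: (n_docs,) retrieval scores for each passage.
    per_doc_token_probs: (n_docs, vocab) generator distribution per passage.
    Returns a single (vocab,) distribution for the next token.
    """
    # p(z|x): softmax over the retrieval scores
    doc_probs = np.exp(doc_scores - doc_scores.max())
    doc_probs /= doc_probs.sum()
    # p(y_t|x) = sum_z p(z|x) * p(y_t|x, z, y_<t)
    return doc_probs @ per_doc_token_probs

scores = np.array([2.0, 0.5, -1.0])   # 3 retrieved passages
token_probs = np.array([              # toy 4-word vocabulary
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.25, 0.25, 0.25, 0.25],
])
mixed = rag_token_step(scores, token_probs)
print(mixed)  # a valid distribution dominated by the top-scoring passage
```

Because the mixture is recomputed per token, different tokens of one answer can be grounded in different passages, which is what distinguishes RAG-Token from RAG-Sequence.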

Training

The model is fine-tuned on the Natural Questions dataset, with wiki_dpr serving as the knowledge base. The retriever and question encoder are trained to work together, so that the passages retrieved from the index ground the generator's answers.
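Conceptually, the retriever scores each passage by the inner product between the DPR question embedding and the pre-computed passage embeddings, then keeps the top-k. A minimal sketch with hypothetical toy vectors (real DPR embeddings are 768-dimensional):

```python
import numpy as np

def retrieve_top_k(question_emb, passage_embs, k=2):
    """Score passages by inner product with the question embedding
    (as DPR does) and return the indices of the top-k passages."""
    scores = passage_embs @ question_emb
    return np.argsort(scores)[::-1][:k]

# Toy data: one question vector, three passage vectors.
question = np.array([1.0, 0.0, 1.0])
passages = np.array([
    [0.9, 0.1, 0.8],   # relevant
    [-0.5, 1.0, 0.0],  # off-topic
    [1.0, 0.2, 1.1],   # most relevant
])
print(retrieve_top_k(question, passages, k=2))  # indices of the 2 best passages
```

In the actual model this maximum-inner-product search runs over millions of wiki_dpr passages via a FAISS index rather than a dense matrix product.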

Guide: Running Locally

To run the model locally, follow these steps:

  1. Install Dependencies:
    Ensure you have the transformers library installed; the retriever additionally requires datasets and faiss.

    pip install transformers datasets faiss-cpu
    
  2. Load Model Components:
    Use the RagTokenizer, RagRetriever, and RagTokenForGeneration classes to load the model.

    from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
    
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
    model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
    
  3. Prepare Input and Generate Output:
    Prepare the input sequence and generate responses.

    # prepare_seq2seq_batch is deprecated in recent transformers releases;
    # calling the tokenizer directly produces the same input_ids.
    input_dict = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
    generated = model.generate(input_ids=input_dict["input_ids"])
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
    
  4. Hardware Requirements:
    The full legacy index requires over 75 GB of RAM; the use_dummy_dataset=True option above loads a small dummy index instead. To work with the full index, consider a cloud instance (e.g. AWS or Google Cloud) with ample memory and GPU support.

License

The RAG-Token Model is distributed under the Apache-2.0 license.
