facebook/rag-token-nq

Introduction
The RAG-Token model, introduced in the paper "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks" by Patrick Lewis et al., is designed for generating answers to knowledge-intensive questions. It is uncased: all input text is lowercased. The model combines a question encoder, a retriever, and a generator, fine-tuned jointly for question answering, with the `wiki_dpr` dataset serving as the retrieval corpus.
Architecture
The RAG-Token model includes:

- Question Encoder: based on `facebook/dpr-question_encoder-single-nq-base`.
- Retriever: extracts relevant passages from the `wiki_dpr` dataset.
- Generator: built on `facebook/bart-large`, responsible for answer generation.

These components are jointly fine-tuned end-to-end on the QA dataset.
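At decode time, RAG-Token combines these components by marginalizing over the retrieved documents at every generation step: each output token's distribution is a mixture of the generator's per-document distributions, weighted by the retriever's document scores. Below is a minimal NumPy sketch of one such step; it is illustrative only, not the `transformers` implementation, and the array shapes and names are assumptions:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def rag_token_step(doc_scores, per_doc_token_logits):
    """One (hypothetical) RAG-Token decoding step.

    doc_scores: (n_docs,) retriever scores for the retrieved passages.
    per_doc_token_logits: (n_docs, vocab) generator logits, one row per passage.
    Returns the marginal next-token distribution:
        sum_z p(z | x) * p(y_i | x, z, y_<i)
    """
    p_doc = softmax(doc_scores)                      # p(z | x)
    p_tok = softmax(per_doc_token_logits, axis=-1)   # p(y_i | x, z, y_<i)
    return p_doc @ p_tok                             # mixture over passages

# Toy example: 2 retrieved passages, vocabulary of 3 tokens.
doc_scores = np.array([2.0, 0.0])                    # passage 0 scores higher
per_doc_logits = np.array([[5.0, 0.0, 0.0],          # passage 0 favors token 0
                           [0.0, 5.0, 0.0]])         # passage 1 favors token 1
marginal = rag_token_step(doc_scores, per_doc_logits)
print(marginal)  # dominated by token 0, since passage 0 dominates p(z | x)
```

By contrast, the companion RAG-Sequence variant marginalizes once over whole output sequences rather than at every token.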
Training
The model is fine-tuned end-to-end on a Natural Questions QA dataset (the "nq" in the model name), with `wiki_dpr` serving as the knowledge base for retrieval. During fine-tuning, the retriever's question encoder and the generator are trained jointly, so that retrieval surfaces passages that help the generator produce correct answers.
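Following the paper, the RAG-Token likelihood factorizes per output token, marginalizing over the top-k retrieved documents, and fine-tuning minimizes the negative marginal log-likelihood of the target answers (the document encoder and index stay fixed during fine-tuning):

$$
p_{\text{RAG-Token}}(y \mid x) \approx \prod_{i=1}^{N} \; \sum_{z \in \text{top-}k\left(p(\cdot \mid x)\right)} p_{\eta}(z \mid x)\, p_{\theta}(y_i \mid x, z, y_{1:i-1})
$$

with training objective $\sum_{j} -\log p(y_j \mid x_j)$ over question-answer pairs $(x_j, y_j)$.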
Guide: Running Locally
To run the model locally, follow these steps:

- Install the required libraries:

  Ensure you have the `transformers` library installed; the retriever additionally needs the `datasets` and `faiss-cpu` packages:

  ```
  pip install transformers datasets faiss-cpu
  ```
- Load model components:

  Use the `RagTokenizer`, `RagRetriever`, and `RagTokenForGeneration` classes to load the model:

  ```python
  from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
  retriever = RagRetriever.from_pretrained(
      "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
  )
  model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
  ```
- Prepare input and generate output:

  Prepare the input sequence and generate a response:

  ```python
  input_dict = tokenizer.prepare_seq2seq_batch(
      "who holds the record in 100m freestyle", return_tensors="pt"
  )
  generated = model.generate(input_ids=input_dict["input_ids"])
  print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
  ```

  Note that recent `transformers` releases deprecate `prepare_seq2seq_batch`; calling the tokenizer directly, `tokenizer("who holds the record in 100m freestyle", return_tensors="pt")`, is the current equivalent.
- Hardware requirements:

  The full legacy index requires over 75 GB of RAM. Consider using cloud-based solutions such as AWS or Google Cloud with ample memory and GPU support for efficient processing.
License
The RAG-Token Model is distributed under the Apache-2.0 license.