BigBird RoBERTa Base

google

Introduction

BigBird is a sparse-attention-based transformer model designed to handle longer sequences more efficiently than traditional transformers such as BERT. By extending the Transformer architecture in this way, it achieves significant improvements on long-sequence tasks such as document summarization and question answering. The model was introduced in the paper "Big Bird: Transformers for Longer Sequences" and was pretrained on English text with a masked language modeling (MLM) objective.

Architecture

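BigBird replaces the full self-attention used in models like BERT with block sparse attention. Full attention scores every token against every other token, so its cost grows quadratically with sequence length. Block sparse attention instead lets each block of tokens attend only to a few global, neighboring, and random blocks, so its cost grows roughly linearly. This allows the model to handle sequences of up to 4096 tokens at a much lower compute cost than BERT, which makes it well suited to tasks that require processing lengthy documents.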
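The attention pattern can be illustrated at the block level. The sketch below is a simplified model of the idea, not the Transformers implementation: it marks which key blocks each query block attends to, assuming one global block at each end of the sequence, a sliding window of 3 blocks, 3 random blocks per row, and a block size of 64 tokens. The function name and its parameters are hypothetical. It then compares how many block pairs are scored against full attention for a 4096-token input.

```python
import random

# Hypothetical, simplified illustration of a block sparse attention pattern;
# not the Transformers implementation.


def block_sparse_mask(num_blocks, window=3, num_random=3, num_global=1, seed=0):
    """Return a num_blocks x num_blocks list of bools: mask[i][j] is True
    when query block i attends to key block j."""
    rng = random.Random(seed)
    mask = [[False] * num_blocks for _ in range(num_blocks)]
    # The first and last num_global blocks act as global blocks
    global_blocks = set(range(num_global)) | set(range(num_blocks - num_global, num_blocks))
    half = window // 2
    for i in range(num_blocks):
        if i in global_blocks:
            # Global query blocks attend to every key block
            mask[i] = [True] * num_blocks
            continue
        # Every block attends to the global blocks
        for j in global_blocks:
            mask[i][j] = True
        # Sliding window: the block itself and its immediate neighbors
        for j in range(max(0, i - half), min(num_blocks, i + half + 1)):
            mask[i][j] = True
        # Random blocks, drawn from those not already attended
        candidates = [j for j in range(num_blocks) if not mask[i][j]]
        for j in rng.sample(candidates, min(num_random, len(candidates))):
            mask[i][j] = True
    return mask


seq_len, block_size = 4096, 64  # the 64-token block size is an assumption
num_blocks = seq_len // block_size  # 64 blocks
mask = block_sparse_mask(num_blocks)
sparse = sum(map(sum, mask))
full = num_blocks * num_blocks
print(f"{sparse} of {full} block pairs scored ({full / sparse:.1f}x fewer)")
# 622 of 4096 block pairs scored (6.6x fewer)
```

Because each non-global row touches a fixed number of blocks, the sparse count grows linearly with the number of blocks while the full count grows quadratically, so the gap widens as inputs get longer.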

Training

The model was pretrained on four datasets: BookCorpus, Wikipedia, CC-News, and Stories. It uses the same SentencePiece vocabulary as RoBERTa, which RoBERTa in turn borrowed from GPT-2. During preprocessing, documents longer than 4096 tokens were split into multiple documents, and documents much shorter than 4096 tokens were joined together. Following BERT's training recipe, 15% of tokens were masked and the model was trained to predict them. Training was initialized from a RoBERTa checkpoint.
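The preprocessing and masking steps can be sketched in plain Python. Both helpers, pack_documents and mask_tokens, are hypothetical and operate on lists of token ids. The greedy joining strategy is an assumption, since the exact rule for joining short documents is not described here. The masking is also simplified: BERT's recipe replaces most selected tokens with [MASK] but swaps some for random tokens or leaves them unchanged, whereas this sketch always uses the mask id. With the real tokenizer, tokenizer.mask_token_id would supply that id.

```python
import random

# Hypothetical helpers illustrating the pretraining data pipeline; not Google's code.


def pack_documents(docs, max_len=4096):
    """Split token lists longer than max_len, then greedily join the pieces
    into examples of at most max_len tokens."""
    pieces = []
    for doc in docs:
        # Split long documents into consecutive max_len-sized chunks
        for start in range(0, len(doc), max_len):
            pieces.append(doc[start:start + max_len])
    examples, current = [], []
    for piece in pieces:
        # Start a new example when the next piece would overflow max_len
        if current and len(current) + len(piece) > max_len:
            examples.append(current)
            current = []
        current = current + piece
    if current:
        examples.append(current)
    return examples


def mask_tokens(tokens, mask_id, rate=0.15, seed=0):
    """Replace a random `rate` fraction of positions with mask_id.

    Returns (inputs, labels). Labels hold the original id at masked positions
    and -100 elsewhere, the value PyTorch's cross-entropy loss ignores by default.
    """
    rng = random.Random(seed)
    n_mask = round(len(tokens) * rate)
    positions = set(rng.sample(range(len(tokens)), n_mask))
    inputs = [mask_id if i in positions else t for i, t in enumerate(tokens)]
    labels = [t if i in positions else -100 for i, t in enumerate(tokens)]
    return inputs, labels


docs = [list(range(10000)), list(range(300)), list(range(500))]
print([len(e) for e in pack_documents(docs)])  # [4096, 4096, 2608]

inputs, labels = mask_tokens(list(range(100)), mask_id=-1)
print(inputs.count(-1))  # 15
```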

Guide: Running Locally

To use BigBird in PyTorch, follow these steps:

  1. Install Transformers, together with PyTorch and SentencePiece (BigBirdTokenizer depends on SentencePiece):

    pip install transformers sentencepiece torch
    
  2. Load the BigBird model:

    from transformers import BigBirdModel, BigBirdTokenizer
    
    # Download the pretrained weights and the matching SentencePiece tokenizer
    model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
    tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
    
  3. Encode input text and obtain model output:

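    text = "Replace me by any text you'd like."
    # Tokenize and return PyTorch tensors ('pt')
    encoded_input = tokenizer(text, return_tensors='pt')
    # Forward pass; output.last_hidden_state holds one vector per token
    output = model(**encoded_input)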
    

Processing long sequences is memory- and compute-intensive, so a GPU is recommended; cloud GPUs from providers such as AWS, Google Cloud, or Azure are an option if local hardware is unavailable.
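In the Transformers implementation, block sparse attention is only used when the input is long enough. The helper below is a hypothetical sketch that predicts which attention type BigBirdModel will run for a given input length and how far block sparse mode pads the input. Its threshold and padding rule are paraphrased, as an assumption, from the library's BigBird source, and the defaults block_size=64 and num_random_blocks=3 are assumed to match this checkpoint's config; check config.json and your installed version before relying on it.

```python
# Hypothetical helper; the threshold and padding rules paraphrase
# transformers' BigBird modeling code and may differ between versions.


def bigbird_attention_plan(seq_len, block_size=64, num_random_blocks=3):
    """Return (attention_type, padded_len) that BigBirdModel is expected
    to use for an input of seq_len tokens."""
    # Block sparse attention needs more tokens than 2 global blocks,
    # 3 sliding-window blocks, and 2 * num_random_blocks random blocks
    max_tokens_to_attend = (5 + 2 * num_random_blocks) * block_size
    if seq_len <= max_tokens_to_attend:
        # Too short: the model warns and falls back to full attention
        return "original_full", seq_len
    # Otherwise the input is padded up to a multiple of block_size
    padding = (block_size - seq_len % block_size) % block_size
    return "block_sparse", seq_len + padding


for n in (12, 1000, 4096):
    print(n, bigbird_attention_plan(n))
# 12 ('original_full', 12)
# 1000 ('block_sparse', 1024)
# 4096 ('block_sparse', 4096)
```

Under these assumptions, a short sentence like the one in step 3 runs with full attention, so the efficiency gains only appear on inputs of several hundred tokens or more. The attention type can be set explicitly when loading, for example BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full"), and block_size and num_random_blocks can be overridden the same way.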

License

BigBird is licensed under the Apache License 2.0.
