Switch-C 2048 (google/switch-c-2048)


Introduction

Switch Transformers is a Mixture of Experts (MoE) model trained on the Masked Language Modeling (MLM) task. The architecture is similar to T5, but its feed-forward layers are replaced with sparse MLP layers in which a learned router dispatches each token to one of many "expert" MLPs. Because only one expert runs per token, this allows for faster training and improved performance on fine-tuned tasks. The design scales up to trillion-parameter models and achieves a 4x pre-training speedup over the T5-XXL model.
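
To make the sparse layer concrete, here is a minimal, hypothetical PyTorch sketch of top-1 ("switch") routing. The class name SwitchFFN and all hyperparameters are illustrative, and details such as expert capacity factors and the load-balancing loss are omitted:

    import torch
    import torch.nn as nn

    class SwitchFFN(nn.Module):
        """Sketch of a top-1 routed mixture-of-experts feed-forward layer."""
        def __init__(self, d_model: int, d_ff: int, num_experts: int):
            super().__init__()
            self.router = nn.Linear(d_model, num_experts)
            self.experts = nn.ModuleList(
                [nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
                 for _ in range(num_experts)]
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (num_tokens, d_model); each token is sent to exactly one expert.
            probs = torch.softmax(self.router(x), dim=-1)
            gate, expert_idx = probs.max(dim=-1)  # top-1 gate value and expert index
            out = torch.zeros_like(x)
            for i, expert in enumerate(self.experts):
                mask = expert_idx == i
                if mask.any():
                    # Scaling the expert output by its gate keeps routing differentiable.
                    out[mask] = gate[mask].unsqueeze(-1) * expert(x[mask])
            return out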

Model Details

  • Model Type: Language model
  • Languages: English
  • License: Apache 2.0
  • Related Models: All FLAN-T5 Checkpoints
  • Resources: Research paper and GitHub repository for further information

Training

The model was trained on the Colossal Clean Crawled Corpus (C4) dataset with a Masked Language Modeling objective. Training was conducted on TPU v3 or v4 pods using the t5x codebase together with JAX. The model's performance has been compared to T5 across various tasks, with detailed results available in the research paper.
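
As a concrete illustration of the objective: T5-style masked language modeling ("span corruption") replaces contiguous spans of the input with sentinel tokens and trains the model to spell out each masked span after its sentinel. The example sentence below is illustrative:

    # T5-style span corruption: spans in the input are replaced with sentinel
    # tokens, and the target reconstructs each masked span after its sentinel.
    original = "Thank you for inviting me to your party last week ."
    inputs   = "Thank you <extra_id_0> me to your party <extra_id_1> week ."
    targets  = "<extra_id_0> for inviting <extra_id_1> last <extra_id_2>"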

Guide: Running Locally

Basic Steps

  1. Install Dependencies:

    pip install transformers accelerate
    
  2. Load the Model:

    from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google/switch-c-2048")
    # device_map="auto" shards the model across available devices; weights that
    # do not fit are offloaded to disk. Replace <OFFLOAD_FOLDER> with a path to
    # a writable directory.
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "google/switch-c-2048",
        device_map="auto",
        offload_folder=<OFFLOAD_FOLDER>,
    )
    
  3. Prepare Input:

    input_text = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
  4. Generate Output:

    # The model fills in each <extra_id_n> sentinel in the input.
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0]))
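
To sanity-check this pipeline without downloading the full checkpoint, the same classes can be pointed at the much smaller google/switch-base-8 checkpoint from the same model family. A minimal sketch (generation settings are illustrative):

    from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

    # Same API as switch-c-2048, but a far smaller download.
    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
    model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")

    input_ids = tokenizer(
        "A <extra_id_0> walks into a bar and orders a <extra_id_1>.",
        return_tensors="pt",
    ).input_ids
    outputs = model.generate(input_ids, max_new_tokens=20)
    print(tokenizer.decode(outputs[0]))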

Cloud GPUs

Due to the model's large size, it is advisable to run it on cloud GPUs such as those provided by AWS, Google Cloud, or Azure. Ensure that the environment supports disk offloading to manage memory efficiently.
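
As a further memory-saving option, loading the weights in half precision roughly halves their footprint before any offloading kicks in. A hedged sketch, assuming bfloat16 support for this checkpoint; the offload path is illustrative and not from the original guide:

    import torch
    from transformers import SwitchTransformersForConditionalGeneration

    # Load weights in bfloat16 to roughly halve memory; anything that still
    # does not fit on the available devices is spilled to the offload folder.
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "google/switch-c-2048",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        offload_folder="offload",  # illustrative path; any writable directory works
    )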

License

The model is licensed under the Apache 2.0 License.
