google/switch-base-8

Introduction

Switch Transformers is a Mixture of Experts (MoE) model for efficient language modeling. Its architecture is similar to T5, but the dense feed-forward layers are replaced by sparse MLP layers containing "experts", with each token routed to a single expert (top-1 routing). This design aims to speed up training while improving performance on fine-tuned tasks.
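
To make the routing idea concrete, the snippet below is a minimal PyTorch sketch of a Switch-style feed-forward layer: a small router picks one expert per token, and that expert's output is scaled by the router probability. The class and parameter names are illustrative, not the implementation used in the released checkpoints.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwitchFeedForward(nn.Module):
        """Minimal sketch of a Switch-style sparse MLP layer (top-1 routing)."""

        def __init__(self, d_model: int, d_ff: int, num_experts: int):
            super().__init__()
            # One dense MLP "expert" per slot; only one is applied to each token.
            self.experts = nn.ModuleList(
                [nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
                 for _ in range(num_experts)]
            )
            # The router produces one logit per expert for every token.
            self.router = nn.Linear(d_model, num_experts)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_len, d_model) flattened to (tokens, d_model)
            tokens = x.reshape(-1, x.size(-1))
            probs = F.softmax(self.router(tokens), dim=-1)   # (tokens, num_experts)
            top_prob, top_idx = probs.max(dim=-1)            # top-1 expert per token
            out = torch.zeros_like(tokens)
            for e, expert in enumerate(self.experts):
                mask = top_idx == e
                if mask.any():
                    # Scale the expert output by its router probability, as in the paper.
                    out[mask] = expert(tokens[mask]) * top_prob[mask].unsqueeze(-1)
            return out.reshape_as(x)

    # Tiny usage example: 2 sequences of 4 tokens, hidden size 16, 8 experts.
    layer = SwitchFeedForward(d_model=16, d_ff=64, num_experts=8)
    y = layer(torch.randn(2, 4, 16))
    print(y.shape)  # torch.Size([2, 4, 16])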

Architecture

  • Model Type: Language model
  • Language: English
  • Key Components: Sparse MLP layers with "experts"
  • License: Apache 2.0
  • Related Models: Other Switch Transformers checkpoints are available (different sizes and numbers of experts)
  • Resources: Research paper, GitHub repository, and Hugging Face documentation

Training

  • Data: Trained on the Colossal Clean Crawled Corpus (C4) dataset using Masked Language Modeling (MLM).
  • Procedure: Trained on TPU v3 or TPU v4 pods using the T5x codebase with JAX.
  • Original Paper: Provides comprehensive details on the training methodology and model evaluation.
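
For intuition, the MLM objective here is T5-style span corruption: contiguous spans of the input are replaced by sentinel tokens (<extra_id_0>, <extra_id_1>, ...) and the model learns to generate the dropped spans. A hypothetical training pair might look like the following (illustrative text, not real C4 data; the actual span sampling follows the T5 recipe):

    # Hypothetical span-corruption pair (illustrative, not real C4 data).
    original     = "The quick brown fox jumps over the lazy dog."

    # Encoder input: each corrupted span is replaced by a sentinel token.
    model_input  = "The <extra_id_0> fox <extra_id_1> the lazy dog."

    # Decoder target: each sentinel is followed by the span it replaced.
    model_target = "<extra_id_0> quick brown <extra_id_1> jumps over <extra_id_2>"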

Guide: Running Locally

  1. Install Dependencies:

    pip install transformers accelerate
    
  2. Load Model and Tokenizer:

    from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
    
    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
    model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
    
  3. Prepare Input and Generate Output:

    input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0]))
    
  4. Running on GPU:

    • Modify the model loading to use the GPU and move the inputs onto it (a fuller sketch combining these steps appears after this guide):
      model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto")
      input_ids = input_ids.to(0)
      
  5. Cloud GPUs:

    • If a local GPU is not available, consider cloud providers such as Google Cloud Platform (GCP), which offer GPU instances as well as the TPU pods used for the original training.
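
Putting steps 2-4 together, the sketch below loads the checkpoint onto the available GPU(s) and runs the same prompt. The bfloat16 dtype is an optional assumption to reduce memory and is not prescribed by the model card.

    import torch
    from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
    # device_map="auto" (via accelerate) places the weights on the available GPU(s);
    # torch_dtype=torch.bfloat16 is an optional memory saving, assuming hardware support.
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "google/switch-base-8", device_map="auto", torch_dtype=torch.bfloat16
    )

    input_text = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)  # move inputs to the first GPU

    outputs = model.generate(input_ids, max_new_tokens=20)
    print(tokenizer.decode(outputs[0]))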

License

Switch Transformers is released under the Apache 2.0 License, which permits broad use, modification, and redistribution, provided the license and copyright notices are retained.
