google/switch-base-8
Introduction
Switch Transformers is a Mixture of Experts (MoE) model that applies a novel architecture for efficient language modeling. It is similar to the T5 model, but its dense feed-forward layers are replaced with sparse MLP layers containing multiple "experts", with each token routed to a single expert. This design aims to improve training speed and performance on fine-tuned tasks.
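As a quick way to see this sparsity, the checkpoint's configuration can be inspected. This is a minimal sketch: it assumes the `transformers` library is installed with network access to the Hugging Face Hub, and the attribute name comes from the `SwitchTransformersConfig` class.

```python
from transformers import AutoConfig

# Load the configuration of the 8-expert base checkpoint from the Hub.
config = AutoConfig.from_pretrained("google/switch-base-8")

# Each sparse feed-forward layer routes every token to one of these experts.
print(config.num_experts)  # expected: 8
```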
Architecture
- Model Type: Language model
- Language: English
- Key Components: Sparse MLP layers with "experts"
- License: Apache 2.0
- Related Models: Other Switch Transformers checkpoints are available
- Resources: Research paper, GitHub repository, and Hugging Face documentation
Training
- Data: Trained on the Colossal Clean Crawled Corpus (C4) dataset using Masked Language Modeling (MLM).
- Procedure: Trained on TPU v3 or TPU v4 pods using the T5X codebase with JAX.
- Original Paper: Offers comprehensive details on training methodologies and model evaluation.
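For intuition, the MLM objective here is T5-style span corruption: spans of the input are replaced with sentinel tokens, and the target reproduces the dropped spans after their sentinels. Below is a minimal illustrative pair, not an example taken from the C4 data itself.

```python
# Illustrative span-corruption pair (T5-style MLM), not real training data.
corrupted_input = "The <extra_id_0> walks in <extra_id_1> park"
target_output = "<extra_id_0> cute dog <extra_id_1> the <extra_id_2>"
```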
Guide: Running Locally
- Install dependencies:

```bash
pip install transformers accelerate
```
- Load the model and tokenizer:

```python
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
```
- Prepare the input and generate output:

```python
input_text = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
```
- Running on GPU: modify the model loading to use `device_map="auto"` and move the inputs to the GPU:

```python
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto")
input_ids = input_ids.to(0)
```
- Cloud GPUs: Consider cloud services such as Google Cloud Platform (GCP), which provide access to GPUs and TPU Pods for improved performance.
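As an optional variation on the GPU step above, memory usage can be reduced by loading the weights in half precision. This is a sketch assuming a bfloat16-capable GPU; `torch_dtype` is a standard `from_pretrained` argument.

```python
import torch
from transformers import SwitchTransformersForConditionalGeneration

# Optional: load the checkpoint in bfloat16 to roughly halve GPU memory use.
model = SwitchTransformersForConditionalGeneration.from_pretrained(
    "google/switch-base-8",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```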
License
Switch Transformers is licensed under the Apache 2.0 License, allowing for wide usage and modification with proper attribution.