Bamba-9B
Introduction
Bamba-9B is a decoder-only language model designed for diverse text generation tasks. It is based on the Mamba-2 architecture and undergoes a two-stage training process to refine its performance.
Architecture
The Bamba-9B model has 9.78 billion parameters arranged in 32 layers, with a hidden dimension of 4096 and 32 attention heads. The context length is 4096 tokens, and the model does not use tied embeddings.
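These values can be read directly from the model's published configuration. Below is a minimal sketch, assuming the standard Hugging Face config field names (the exact keys in a given release may differ):

    from transformers import AutoConfig

    # Pull config.json from the Hugging Face Hub and print the
    # numbers quoted above; field names follow common HF conventions.
    config = AutoConfig.from_pretrained("ibm-fms/Bamba-9B")
    print(config.hidden_size)              # 4096
    print(config.num_hidden_layers)        # 32
    print(config.num_attention_heads)      # 32
    print(config.max_position_embeddings)  # 4096
    print(config.tie_word_embeddings)      # False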
Training
Bamba-9B is trained with FSDP1 using the official Mamba implementation. The first stage trains on 2 trillion tokens; a second stage follows with 200 billion tokens of high-quality data. Users can now also train the Hugging Face version of Mamba2-Hybrid.
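FSDP shards parameters, gradients, and optimizer state across data-parallel workers, which is what lets a 9.78-billion-parameter model fit during training. The sketch below shows a generic FSDP1 wrap in PyTorch; it is not the actual Bamba training script, and the optimizer and learning rate are purely illustrative:

    # Generic FSDP1 sketch -- not the Bamba training script.
    # Launch with: torchrun --nproc_per_node=<num_gpus> train_sketch.py
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoModelForCausalLM

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
    # Wrapping shards parameters, gradients, and optimizer state
    # across all workers in the process group.
    model = FSDP(model.cuda())

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # illustrative values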
Guide: Running Locally
- Installation:
  - Install PyTorch and additional dependencies (an import check for these packages follows this list):

        git clone https://github.com/Dao-AILab/causal-conv1d.git
        cd causal-conv1d && pip install . && cd ..
        git clone https://github.com/state-spaces/mamba.git
        cd mamba && pip install . && cd ..
        git clone https://github.com/Dao-AILab/flash-attention.git
        cd flash-attention && pip install . && cd ..
  - Install the latest transformers package:

        pip install git+https://github.com/huggingface/transformers.git
- Inference (a GPU-oriented variant of this snippet appears after this list):

        # Load the model and tokenizer from the Hugging Face Hub,
        # then generate a short completion (on CPU by default).
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
        tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

        message = ["Mamba is a snake with following properties "]
        inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
        response = model.generate(**inputs, max_new_tokens=64)
        print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
- Cloud GPUs: Utilize cloud GPU services such as AWS, Google Cloud, or Azure to run a model of this size efficiently.
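Once the kernel packages and transformers are installed, the sketch below first checks that the compiled dependencies import cleanly, then runs the same generation on GPU. It is a variant of the snippet above rather than part of the official instructions: bfloat16 and device_map="auto" are illustrative choices, and device_map assumes the accelerate package is installed.

    # Sanity check: the kernels built in the installation step
    # should import cleanly under the names below.
    import causal_conv1d
    import mamba_ssm
    import flash_attn

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # bfloat16 halves memory versus float32; device_map="auto"
    # (which requires the accelerate package) places weights on
    # the available GPUs.
    model = AutoModelForCausalLM.from_pretrained(
        "ibm-fms/Bamba-9B",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

    inputs = tokenizer(
        ["Mamba is a snake with following properties "],
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(model.device)
    response = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])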
License
Bamba-9B is licensed under the Apache-2.0 license, allowing for wide use and distribution with proper attribution.