Introduction

Bamba-9B is a decoder-only language model for general text generation tasks. It is based on the Mamba-2 hybrid architecture and is trained in two stages to refine its performance.

Architecture

The Bamba-9B model consists of 9.78 billion parameters, with 32 layers, a hidden dimension of 4096, and 32 attention heads. The context length is 4096, and the model does not use tied embeddings.
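
These values can be read straight from the checkpoint's configuration. A minimal sketch (the attribute names below are assumed to follow standard Hugging Face config conventions; they are not taken from the model card):

    from transformers import AutoConfig

    # Load the configuration that ships with the checkpoint.
    config = AutoConfig.from_pretrained("ibm-fms/Bamba-9B")

    # Field names assumed from Hugging Face conventions.
    print(config.num_hidden_layers)    # expected: 32
    print(config.hidden_size)          # expected: 4096
    print(config.num_attention_heads)  # expected: 32
    print(config.tie_word_embeddings)  # expected: False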

Training

Bamba-9B is trained with PyTorch FSDP1 and the official Mamba implementation. The first stage trains on 2 trillion tokens; the second stage continues on 200 billion tokens of high-quality data. A Hugging Face version of the Mamba2-Hybrid model is now available for training as well.
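
The published recipe is not reproduced here, but the FSDP1 wrapping can be sketched roughly as follows (an illustrative sketch only: it assumes a torchrun launch with one process per GPU, and the optimizer settings and wrapping policy are placeholders, not the published configuration):

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoModelForCausalLM

    # Assumes launch via torchrun, one process per GPU.
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")

    # FSDP1 shards parameters, gradients, and optimizer state across ranks.
    model = FSDP(model, device_id=torch.cuda.current_device())

    # Placeholder optimizer; not the published hyperparameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # ...standard training loop over the token stream goes here...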

Guide: Running Locally

  1. Installation:

    • Install PyTorch first, then build the Mamba kernels and FlashAttention from source:
      git clone https://github.com/Dao-AILab/causal-conv1d.git
      cd causal-conv1d && pip install . && cd ..
      git clone https://github.com/state-spaces/mamba.git
      cd mamba && pip install . && cd ..
      git clone https://github.com/Dao-AILab/flash-attention.git
      cd flash-attention && pip install . && cd ..
      
    • Install the latest transformers package:
      pip install git+https://github.com/huggingface/transformers.git
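      
    • Optionally, verify that the compiled kernels import cleanly (the module names here are inferred from the repositories above):
      python -c "import causal_conv1d, mamba_ssm, flash_attn; print('ok')"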
      
  2. Inference:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    
    # Load the model and its tokenizer from the Hugging Face Hub.
    model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
    tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")
    
    # Tokenize a prompt; token type IDs are not used by this model.
    message = ["Mamba is a snake with following properties  "]
    inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
    
    # Generate up to 64 new tokens and decode the completion.
    response = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
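    
    For GPU inference, the model can also be loaded in half precision. A hedged sketch (a CUDA device is assumed, and device_map="auto" additionally requires the accelerate package):
    
    import torch
    
    # bfloat16 roughly halves memory use relative to fp32.
    model = AutoModelForCausalLM.from_pretrained(
        "ibm-fms/Bamba-9B",
        torch_dtype=torch.bfloat16,
        device_map="auto",  # requires accelerate; places weights automatically
    )
    inputs = inputs.to(model.device)  # move the tokenized prompt to the model's device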
    
  3. Cloud GPUs: For heavier workloads, consider cloud GPU services such as AWS, Google Cloud, or Azure.

License

Bamba-9B is licensed under the Apache-2.0 license, allowing for wide use and distribution with proper attribution.