DiarizationLM-8b-Fisher-v2 (google/DiarizationLM-8b-Fisher-v2)
Introduction
DiarizationLM-8b-Fisher-v2 is a model finetuned on the Fisher corpus for speaker diarization tasks. It is based on the foundation model unsloth/llama-3-8b-bnb-4bit. Unlike its predecessor google/DiarizationLM-8b-Fisher-v1, which computes loss on both prompt and completion tokens, this model computes loss only on completion tokens.
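To make the distinction concrete, here is a minimal sketch of completion-only loss masking in the usual Hugging Face causal-LM style: prompt positions in the label sequence are set to -100 so the cross-entropy loss ignores them. The " --> " separator matches the prompt format in the usage example below; the exact training pipeline is an assumption.

```python
# Minimal sketch of completion-only loss masking (an assumption about
# the training setup, not the model's exact pipeline).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v2")

prompt = "<speaker:1> Hello, how are <speaker:2> you doing today? --> "
completion = "<speaker:1> Hello, how are you doing today?"

prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids
completion_ids = tokenizer(completion, add_special_tokens=False).input_ids

input_ids = prompt_ids + completion_ids
# -100 is the ignore index of PyTorch's cross-entropy loss, so only
# completion tokens contribute to the loss.
labels = [-100] * len(prompt_ids) + completion_ids
```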
Architecture
- Foundation Model: unsloth/llama-3-8b-bnb-4bit
- Parameters: 671,088,640
- Batch Size: 16
- Training Steps: 28,800 (~9 epochs)
- Maximal Prompt Length: 6000 characters
- Maximal Sequence Length: 4096 tokens
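The parameter count corresponds exactly to a rank-256 LoRA adapter over Llama-3-8B's attention and MLP projection matrices across all 32 layers. The quick check below is a sketch under the assumption that all seven standard projection modules are targeted; the layer shapes themselves are the public Llama-3-8B dimensions.

```python
# Sanity check: rank-256 LoRA over Llama-3-8B's projection modules
# reproduces the reported parameter count. The target-module list is
# an assumption.
r, layers = 256, 32
shapes = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),   # grouped-query attention: 8 KV heads
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
# Each LoRA pair adds r * (fan_in + fan_out) parameters per module.
total = layers * sum(r * (i + o) for i, o in shapes.values())
print(f"{total:,}")  # 671,088,640
```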
Training
The model is finetuned using a LoRA adapter of rank 256 on a subset of the Fisher corpus, utilizing mixed data from the hyp2ora and deg2ref flavors. The finetuning was conducted over four days on a Google Cloud VM with an NVIDIA A100 GPU (80 GB memory).
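For reference, a rank-256 adapter of this shape could be declared with the peft library roughly as follows. Only the rank is stated on the card; lora_alpha, dropout, and the target-module list are assumptions.

```python
# Sketch of a rank-256 LoRA configuration using the peft library.
# Hyperparameters other than the rank are assumptions.
from peft import LoraConfig

lora_config = LoraConfig(
    r=256,
    lora_alpha=256,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
```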
Guide: Running Locally
- Install Required Packages:

```bash
pip install transformers diarizationlm
```
- Run the Model (for transcripts longer than the 6000-character prompt limit, see the chunking sketch after this list):

```python
from transformers import LlamaForCausalLM, AutoTokenizer
from diarizationlm import utils

HYPOTHESIS = """<speaker:1> Hello, how are you doing <speaker:2> today? I am doing well. What about <speaker:1> you? I'm doing well, too. Thank you."""

tokenizer = AutoTokenizer.from_pretrained("google/DiarizationLM-8b-Fisher-v2", device_map="cuda")
model = LlamaForCausalLM.from_pretrained("google/DiarizationLM-8b-Fisher-v2", device_map="cuda")

inputs = tokenizer([HYPOTHESIS + " --> "], return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=int(inputs.input_ids.shape[1] * 1.2),
    use_cache=False,
)
completion = tokenizer.batch_decode(
    outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)[0]
# Transfer the speaker labels from the LLM completion back onto the
# original hypothesis words.
transferred_completion = utils.transfer_llm_completion(completion, HYPOTHESIS)

print("Hypothesis:", HYPOTHESIS)
print("Completion:", completion)
print("Transferred completion:", transferred_completion)
```
- Consider Using Cloud GPUs: Deploying the model on a machine with a GPU, such as one offered by Google Cloud or AWS, can significantly speed up inference.
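Since prompts are capped at 6000 characters, long hypotheses have to be processed in segments. The helper below is a hypothetical sketch, not part of the diarizationlm package: it splits a hypothesis at speaker-tag boundaries into chunks under the limit, which can then be sent through the model one at a time.

```python
import re

# Hypothetical helper (not part of the diarizationlm package): split a
# hypothesis at <speaker:N> boundaries into chunks that respect the
# model's 6000-character maximal prompt length.
def split_hypothesis(hypothesis: str, max_chars: int = 6000) -> list[str]:
    # Keep each speaker turn (tag plus the words that follow it) intact.
    turns = re.findall(r"<speaker:\d+>[^<]*", hypothesis)
    chunks, current = [], ""
    for turn in turns:
        if current and len(current) + len(turn) > max_chars:
            chunks.append(current.strip())
            current = ""
        current += turn
    if current:
        chunks.append(current.strip())
    return chunks

for chunk in split_hypothesis("<speaker:1> Hello there. <speaker:2> Hi."):
    print(chunk)
```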
License
The model is released under the Llama 3 license. It is not an officially supported Google product.