HuBERT-Base-SUPERB-ER
Introduction
The hubert-base-superb-er model is a port of S3PRL's HuBERT for the SUPERB Emotion Recognition task. It is based on hubert-base-ls960, which was pretrained on 16kHz sampled speech audio, and predicts an emotion class for each speech utterance using the IEMOCAP dataset.
Architecture
The model uses hubert-base-ls960 as its base architecture, which is pretrained on speech audio sampled at 16kHz. Make sure the speech you feed to the model is also sampled at 16kHz.
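If your audio is stored at a different sampling rate, resample it before inference; a minimal sketch using librosa (the file path is a placeholder):

```python
import librosa

# librosa resamples to the requested rate on load; "my_clip.wav" is a placeholder.
speech, sr = librosa.load("my_clip.wav", sr=16000, mono=True)
assert sr == 16000
```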
Training
The model was trained on the Emotion Recognition (ER) task, which predicts an emotion class for each utterance. It uses the IEMOCAP dataset and follows the conventional evaluation protocol: the unbalanced emotion classes are dropped, leaving four classes with a similar number of data points, and evaluation cross-validates on five folds of the standard splits.
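As a quick sanity check, the four classes can be read off the fine-tuned checkpoint's config; a minimal sketch (the short class names shown in the comment are an assumption about this checkpoint's label map):

```python
from transformers import HubertForSequenceClassification

model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er")

# Print the emotion classes the classification head predicts.
print(model.config.id2label)
# Assumed output: {0: 'neu', 1: 'hap', 2: 'ang', 3: 'sad'}
```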
Guide: Running Locally
Basic Steps
- Install Required Libraries: Ensure you have `transformers`, `torch`, `librosa`, and `datasets` installed:

  ```bash
  pip install transformers torch librosa datasets
  ```
- Load Dataset and Model:

  ```python
  import torch
  import librosa
  from datasets import load_dataset
  from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

  # Decode each audio file to a 16kHz mono waveform.
  def map_to_array(example):
      speech, _ = librosa.load(example["file"], sr=16000, mono=True)
      example["speech"] = speech
      return example

  dataset = load_dataset("anton-l/superb_demo", "er", split="session1")
  dataset = dataset.map(map_to_array)

  model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er")
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
  ```
- Process Inputs and Predict:

  ```python
  # Batch the first four utterances, pad to equal length, and classify.
  inputs = feature_extractor(dataset[:4]["speech"], sampling_rate=16000, padding=True, return_tensors="pt")
  logits = model(**inputs).logits
  predicted_ids = torch.argmax(logits, dim=-1)
  labels = [model.config.id2label[_id] for _id in predicted_ids.tolist()]
  ```
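To see how confident the model is in each prediction, the logits can be turned into probabilities with a softmax; a small follow-on sketch using the `logits` and `labels` from the step above:

```python
import torch

# Per-class probabilities for each utterance in the batch.
probs = torch.nn.functional.softmax(logits, dim=-1)
for label, p in zip(labels, probs.max(dim=-1).values.tolist()):
    print(f"{label}: {p:.2%}")
```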
Cloud GPUs
Consider using cloud GPU services such as AWS EC2, Google Cloud, or Azure for handling large datasets and models efficiently.
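On such an instance, the guide above runs unchanged apart from device placement; a minimal sketch (reusing the `inputs` dict built in the guide):

```python
import torch
from transformers import HubertForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(device)

# Tensors must live on the same device as the model; `inputs` is the
# feature-extractor output from the guide above.
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits
```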
License
The model is distributed under the Apache 2.0 License.