music classifier

gastonduault

Music Genre Classification Model 🎶

Introduction

This model classifies the genre of a piece of music from its audio signal, supplied as a .wav file. It is a Wav2Vec2 model fine-tuned on the music_genres_small dataset, and it predicts one of ten genres for each input.

Architecture

The model builds on the pre-trained facebook/wav2vec2-large checkpoint, which was fine-tuned with a sequence classification head for music genre classification.

Training

  • Dataset: lewtun/music_genres_small
  • Base Model: facebook/wav2vec2-large
  • Metrics:
    • Validation Accuracy: 75%
    • F1 Score: 74%
    • Validation Loss: 0.77
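
The card does not say how these metrics were computed. As a minimal sketch, assuming accuracy is the fraction of correct predictions and the F1 score is macro-averaged over genres (the averaging is an assumption, not stated in the source), they could be derived from true and predicted labels as follows. The function names and the toy labels are purely illustrative.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores (assumed averaging scheme)."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class gains a false positive
            fn[t] += 1  # true class gains a false negative
    scores = []
    for label in labels:
        denom = 2 * tp[label] + fp[label] + fn[label]
        scores.append(2 * tp[label] / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy labels for three classes, not taken from the real validation set
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(f"accuracy: {accuracy(y_true, y_pred):.3f}")
print(f"macro F1: {macro_f1(y_true, y_pred):.3f}")
```

Macro averaging weights every genre equally, so it can diverge from accuracy when some genres are harder to recognize than others.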

Guide: Running Locally

  1. Environment Setup:

    • Ensure you have Python installed, along with the required libraries: transformers, librosa, and torch.
  2. Install Required Libraries:

    pip install transformers librosa torch
    
  3. Load the Model:

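    from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
    import librosa
    import torch
    
    # Fine-tuned genre classifier (downloaded from the Hugging Face Hub on first use)
    model = Wav2Vec2ForSequenceClassification.from_pretrained("gastonduault/music-classifier")
    # Feature extractor of the base checkpoint; it prepares raw audio for the model
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large")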
    
  4. Preprocess Audio:

    def preprocess_audio(audio_path):
        # librosa resamples to 16 kHz (the rate Wav2Vec2 expects) and mixes down to mono
        audio_array, _ = librosa.load(audio_path, sr=16000)
        return feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
    
    audio_path = "./Nirvana - Come As You Are.wav"
    inputs = preprocess_audio(audio_path)
    
  5. Predict Genre:

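    # Inference only, so gradient tracking is disabled
    with torch.no_grad():
        logits = model(**inputs).logits
        # Index of the highest-scoring genre
        predicted_class = torch.argmax(logits, dim=-1).item()
    
    # Class index to genre name
    genre_mapping = {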
        0: "Electronic",
        1: "Rock",
        2: "Punk",
        3: "Experimental",
        4: "Hip-Hop",
        5: "Folk",
        6: "Chiptune / Glitch",
        7: "Instrumental",
        8: "Pop",
        9: "International",
    }
    
    print(f"Song analyzed: {audio_path}")
    print(f"Predicted genre: {genre_mapping[predicted_class]}")
    
  6. Cloud GPU Suggestion: The steps above run on the CPU, where a model of this size can be slow. For faster inference, consider a cloud GPU service such as AWS, Google Cloud, or Azure.
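
Step 5 reports only the single most likely genre. A minimal sketch of turning the logits into probabilities and listing the top few genres follows. The helper names `softmax` and `top_genres`, the default `k`, and the example logits are illustrative assumptions; the genre list repeats the mapping from step 5.

```python
import math

# Same index-to-genre order as the mapping in step 5
GENRES = [
    "Electronic", "Rock", "Punk", "Experimental", "Hip-Hop",
    "Folk", "Chiptune / Glitch", "Instrumental", "Pop", "International",
]


def softmax(logits):
    """Convert raw scores to probabilities (max is subtracted for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_genres(logits, k=3):
    """Return the k most probable (genre, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(GENRES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration, not real model output
example_logits = [0.1, 2.3, 1.7, -0.5, 0.0, -1.2, -2.0, 0.4, 0.9, -0.3]
for genre, prob in top_genres(example_logits):
    print(f"{genre}: {prob:.1%}")
```

With the model output from step 5, the input to `top_genres` would be `logits[0].tolist()`.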

License

The model and associated code are subject to applicable licenses. Please refer to the respective repositories and datasets for specific license details.
