cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all

Introduction

The cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all model is a fine-tuned version of cardiffnlp/twitter-roberta-base-dec2021, tailored for multi-label text classification on tweets. It was trained on the tweet_topic_multi dataset, using the train_all split for fine-tuning and the test_2021 split for validation.

Architecture

This model uses the RoBERTa architecture, implemented in PyTorch via the Hugging Face Transformers library. It is configured for multi-label classification, since a single tweet can touch on several of the dataset's topics at once.
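
As a quick sanity check of this setup, the published configuration can be inspected directly. A minimal sketch follows; whether problem_type is stored on the checkpoint or only passed at load time may vary:

    from transformers import AutoConfig

    # Inspect the published configuration without loading the full weights.
    config = AutoConfig.from_pretrained("cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all")
    print(config.problem_type)  # may be "multi_label_classification" if set on the checkpoint
    print(config.num_labels)    # number of topic labels in tweet_topic_multi
    print(config.id2label)      # index -> topic name mapping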

Training

The model achieves the following results on the test_2021 set:

  • F1 (micro): 0.7648
  • F1 (macro): 0.6187
  • Accuracy: 0.5485
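
For reference, metrics of this form can be computed with scikit-learn by treating predictions and gold labels as binary indicator matrices. The sketch below uses toy arrays, not the actual evaluation script; note that for multi-label data, sklearn's accuracy_score is exact-match (subset) accuracy, which is why the accuracy figure is lower than the F1 scores:

    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score

    # Toy binary indicator matrices: rows = tweets, columns = topic labels.
    y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
    y_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 1]])

    print(f1_score(y_true, y_pred, average="micro"))  # pools all label decisions
    print(f1_score(y_true, y_pred, average="macro"))  # averages per-label F1 scores
    print(accuracy_score(y_true, y_pred))             # exact-match (subset) accuracy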

The fine-tuning script used for training is linked from the model card on Hugging Face.

Guide: Running Locally

To run the model locally, follow these steps:

  1. Install Dependencies
    Ensure you have Python installed, along with the transformers and torch libraries.

    pip install transformers torch
    
  2. Import and Load the Model
    Use the following code to load and run the model:

    import math
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    def sigmoid(x):
        return 1 / (1 + math.exp(-x))

    # Load the tokenizer and the fine-tuned multi-label classification model.
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all")
    model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all", problem_type="multi_label_classification")
    model.eval()
    # id2label maps each output index to its topic name.
    class_mapping = model.config.id2label

    with torch.no_grad():
        text = "#NewVideo Cray Dollas- Water- Ft. Charlie Rose- (Official Music Video)- {{URL}} via {@YouTube@} #watchandlearn {{USERNAME}}"
        tokens = tokenizer(text, return_tensors='pt')
        output = model(**tokens)
        # Apply a sigmoid to each logit and keep every topic above the 0.5 threshold.
        flags = [sigmoid(s) > 0.5 for s in output[0][0].detach().tolist()]
        topic = [class_mapping[n] for n, i in enumerate(flags) if i]
    print(topic)
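
    To classify several tweets at once, the same logic can be vectorized with torch.sigmoid over a padded batch. This is a sketch reusing the tokenizer, model, and class_mapping loaded above, with made-up example tweets:

    texts = [
        "I love this new album, the production is incredible",
        "Big win for the home team in overtime last night",
    ]
    with torch.no_grad():
        batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        logits = model(**batch).logits
        # Boolean matrix: one row per tweet, one column per topic label.
        preds = torch.sigmoid(logits) > 0.5
    for row in preds:
        print([class_mapping[n] for n, flag in enumerate(row.tolist()) if flag])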
    
  3. Consider Cloud GPUs
    For improved performance, especially with large datasets or batches, consider using cloud-based GPU services such as AWS EC2, Google Cloud, or Azure.
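
    If a GPU is available, moving the model and inputs onto it is a small change to the code above (a minimal sketch; for this base-sized model it mainly pays off with large batches):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Move tokenized inputs to the same device before the forward pass.
    tokens = {k: v.to(device) for k, v in tokens.items()}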

License

Please refer to the respective model and dataset licenses on the Hugging Face model page for usage stipulations.
