RoBERTa Fake News Classification

hamzab

Introduction

This project provides a RoBERTa-based model fine-tuned for fake news classification. It is trained on the Kaggle fake-and-real-news-dataset to distinguish real news articles from fake ones. The model takes an article's title and content as input and predicts whether the article is fake or real.

Architecture

The model is based on the roberta-base architecture, fine-tuned for text classification to identify fake news. It is implemented in PyTorch and loaded through the Transformers library.

Training

The model was fine-tuned on the Kaggle fake and real news dataset and is reported to reach 100% accuracy on that dataset. Training optimizes the model's parameters to minimize classification error on this task.
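
As an illustration of how an accuracy figure like this could be checked, the sketch below computes accuracy from score dictionaries shaped like the ones returned by predict_fake in the guide further down. The helper name accuracy and the sample scores are hypothetical and made up for the example; they are not results from the model. Measuring on a held-out split rather than the training data is an assumption about good practice, not something the description above states.

```python
def accuracy(predictions, gold_labels):
    """Fraction of articles whose highest-scoring label matches the gold label."""
    if len(predictions) != len(gold_labels):
        raise ValueError("predictions and gold_labels must have the same length")
    if not predictions:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(
        1
        for scores, gold in zip(predictions, gold_labels)
        if max(scores, key=scores.get) == gold  # label with the highest score
    )
    return correct / len(predictions)


# Made-up scores for three articles (illustrative only, not model output).
example_scores = [
    {"Fake": 0.98, "Real": 0.02},
    {"Fake": 0.10, "Real": 0.90},
    {"Fake": 0.60, "Real": 0.40},
]
example_gold = ["Fake", "Real", "Real"]
example_accuracy = accuracy(example_scores, example_gold)  # 2 of 3 correct
```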

Guide: Running Locally

To use the model locally, follow these steps:

  1. Install the Transformers library and PyTorch:

    pip install transformers torch
    
  2. Load the tokenizer and model (downloaded from the Hugging Face Hub on first use):

    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    
    tokenizer = AutoTokenizer.from_pretrained("hamzab/roberta-fake-news-classification")
    model = AutoModelForSequenceClassification.from_pretrained("hamzab/roberta-fake-news-classification")
    
  3. Make predictions:

    import torch
    
    def predict_fake(title, text):
        # Wrap the article in the marker format used by this model card.
        input_str = "<title>" + title + "<content>" + text + "<end>"
        # Tokenize, padding or truncating to 512 tokens.
        inputs = tokenizer(input_str, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        with torch.no_grad():
            output = model(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device))
        # Softmax over the class dimension; index 0 is "Fake", index 1 is "Real".
        probs = torch.softmax(output.logits, dim=-1)[0]
        return dict(zip(["Fake", "Real"], probs.tolist()))
    
    print(predict_fake("<HEADLINE-HERE>", "<CONTENT-HERE>"))
    
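  4. Test with Gradio (install it first with pip install gradio):

    import gradio as gr
    
    # Build the interface first, then launch it; share=True creates a public link.
    iface = gr.Interface(
        fn=predict_fake,
        inputs=[gr.Textbox(lines=1, label="headline"), gr.Textbox(lines=6, label="content")],
        outputs="label",
    )
    iface.launch(share=True)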
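
The dictionary returned by predict_fake in step 3 holds one probability per class. A minimal sketch of turning it into a single decision is shown below. The function name label_with_confidence, the "Uncertain" fallback, and the 0.8 cutoff are all assumptions chosen for illustration; they are not part of the model or its published usage.

```python
def label_with_confidence(scores, min_confidence=0.8):
    """Return (label, confidence), or ("Uncertain", confidence) below the cutoff.

    scores is a dict such as {"Fake": 0.93, "Real": 0.07}.
    min_confidence is a hypothetical cutoff; tune it for your own use case.
    """
    label = max(scores, key=scores.get)  # class with the highest probability
    confidence = scores[label]
    if confidence < min_confidence:
        return "Uncertain", confidence
    return label, confidence


# Hand-written example scores (not model output).
decision = label_with_confidence({"Fake": 0.95, "Real": 0.05})
borderline = label_with_confidence({"Fake": 0.55, "Real": 0.45})
```

With the model loaded, the same helper can be applied directly to the result of predict_fake(title, text).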
    

Cloud GPUs

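For faster inference, consider running the model on a cloud GPU from a provider such as AWS, Google Cloud, or Azure. The predict_fake function above already moves the model to CUDA when a GPU is available, so no code changes are needed. A GPU helps most when classifying large volumes of articles.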
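
When many articles are processed on a GPU, sending several of them through the model in one forward pass usually makes better use of the hardware than calling the model once per article. The sketch below shows one way to do that. The names batched and predict_many and the default batch size of 16 are assumptions for illustration. It reuses the input format and the Fake/Real label order from the guide above, but pads each batch only to its longest article (padding=True) rather than always to 512 tokens.

```python
def batched(items, batch_size):
    """Yield consecutive slices of items, each with at most batch_size elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def predict_many(articles, tokenizer, model, batch_size=16):
    """Score a list of (title, text) pairs; returns one score dict per article."""
    import torch  # imported here so the batching helper works without PyTorch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    results = []
    for batch in batched(articles, batch_size):
        # Same marker format as predict_fake in the guide above.
        texts = ["<title>" + title + "<content>" + text + "<end>" for title, text in batch]
        enc = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**enc).logits
        # One row of class probabilities per article in the batch.
        for row in torch.softmax(logits, dim=-1).tolist():
            results.append(dict(zip(["Fake", "Real"], row)))
    return results
```

With the tokenizer and model from step 2 loaded, a call might look like predict_many([("headline", "body text")], tokenizer, model). A smaller batch_size may be needed if GPU memory is limited.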

License

The model is released under the MIT License, allowing for wide usage and modification with minimal restrictions.
