Reddit-memes-pixtral-12B-v4

AlexandrosChariton

Introduction

The Reddit-memes-pixtral-12B-v4 model is a custom adaptation of the Pixtral-12B model from the mistral-community. It processes both images and text to generate comments for Reddit posts in the r/memes subreddit, and it is designed to produce content aimed at maximizing user engagement, reflecting the style of popular, highly upvoted Reddit comments.

Architecture

The model is based on Pixtral-12B, a multimodal language model that processes both images and text to generate text. This implementation uses LoRA (Low-Rank Adaptation) to fine-tune the original model on high-engagement Reddit comments: only a small set of low-rank adapter weights is trained while the base model weights remain frozen.
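
As an illustration of the LoRA setup, here is a minimal sketch using the peft library. The rank, alpha, dropout, and target modules below are assumptions chosen for illustration, not the exact settings used for this model.

    from peft import LoraConfig, get_peft_model
    from transformers import LlavaForConditionalGeneration
    import torch

    # Hypothetical LoRA hyperparameters; the actual values used for
    # Reddit-memes-pixtral-12B-v4 are not documented here.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )

    # Wrap the frozen base model with the trainable LoRA adapters.
    base = LlavaForConditionalGeneration.from_pretrained("mistral-community/pixtral-12b", torch_dtype=torch.float16, device_map="cuda")
    peft_model = get_peft_model(base, lora_config)
    peft_model.print_trainable_parameters()  # on the order of a few percent of all parameters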

Training

The model was fine-tuned on a dataset of Reddit comments that received significant upvotes, comprising roughly 1.5k posts and 12k comments selected for quality and engagement. Approximately 3% of the model's parameters were trainable during fine-tuning. Basic filtering was applied to the data to remove low-quality or controversial content and to reduce the risk of inappropriate or misleading outputs.
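
To make the "basic filtering" step concrete, here is a hedged sketch of what such a pass over scraped comments could look like. The field names, score threshold, and sample records are assumptions for illustration and do not describe the actual preprocessing pipeline.

    # Hypothetical filtering pass over scraped r/memes comments.
    # Field names ("body", "score") and MIN_SCORE are assumptions.
    MIN_SCORE = 50
    REMOVED_MARKERS = {"[deleted]", "[removed]"}

    def keep_comment(comment: dict) -> bool:
        body = comment.get("body", "").strip()
        if not body or body in REMOVED_MARKERS:
            return False  # drop empty, deleted, or removed comments
        return comment.get("score", 0) >= MIN_SCORE  # keep only well-upvoted comments

    comments = [
        {"body": "This is painfully relatable", "score": 812},
        {"body": "[deleted]", "score": 4},
    ]
    filtered = [c for c in comments if keep_comment(c)]  # keeps only the first record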

Guide: Running Locally

  1. Setup Environment:

    • Ensure you have Python installed along with the required libraries: transformers, peft, torch, and Pillow (PIL).
  2. Install Dependencies:

    pip install transformers peft accelerate torch pillow
    
  3. Load the Model and Processor:

    from peft import PeftModel, PeftConfig
    from transformers import AutoProcessor, LlavaForConditionalGeneration
    from PIL import Image
    import torch
    
    # Load the base Pixtral-12B model and its processor in half precision on the GPU.
    model_id = "mistral-community/pixtral-12b"
    model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
    processor = AutoProcessor.from_pretrained(model_id)
    
    # Attach the LoRA adapter weights on top of the base model.
    # PeftConfig is only needed if you want to inspect the adapter settings.
    peft_config = PeftConfig.from_pretrained("AlexandrosChariton/Reddit-memes-pixtral-12B-v4")
    lora_model = PeftModel.from_pretrained(model, "AlexandrosChariton/Reddit-memes-pixtral-12B-v4")
    
  4. Prepare Image and Generate Comments:

    # Post title and meme image to comment on.
    image_path = "meme_image.png"
    meme_title = "I hate it when this happens to me"
    
    # Downscale the image to keep the number of image tokens manageable.
    image = Image.open(image_path).convert("RGB")
    image.thumbnail((512, 512))
    
    # Pixtral instruction format: the [IMG] token marks where the image is inserted.
    PROMPT = f"<s>[INST]Come up with a comment that will get upvoted by the community for a reddit post in r/memes. Provide the comment body with text and nothing else. The post has title: '{meme_title}' and image:\n[IMG][/INST]"
    inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")
    
    generate_ids = lora_model.generate(**inputs, max_new_tokens=650)
    # The decoded string includes the prompt; the generated comment follows the final [/INST].
    output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print(output)
    
  5. Cloud GPUs: Running a 12B-parameter multimodal model in half precision requires substantial GPU memory, so consider cloud GPU services such as AWS, GCP, or Azure for inference; a quantized-loading sketch that lowers the memory requirement follows this list.
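
If GPU memory is limited, 4-bit quantized loading can shrink the footprint considerably. The sketch below is an illustrative assumption rather than part of the original instructions: it uses transformers' BitsAndBytesConfig and requires the bitsandbytes package to be installed.

    import torch
    from peft import PeftModel
    from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration
    
    # Quantized loading sketch; requires `pip install bitsandbytes`.
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    
    model_id = "mistral-community/pixtral-12b"
    model = LlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
    processor = AutoProcessor.from_pretrained(model_id)
    lora_model = PeftModel.from_pretrained(model, "AlexandrosChariton/Reddit-memes-pixtral-12B-v4")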

License

The model is released under the MIT License, which permits use, modification, and distribution provided the license and copyright notice are retained, crediting the original author.
