ruDialoGPT-medium

t-bank-ai

Introduction

The ruDialoGPT-medium model is a conversational language model published by Tinkoff (now T-Bank AI), built on top of sberbank-ai/rugpt3medium_based_on_gpt2. It generates open-domain dialogue in Russian, inherits the GPT-2 architecture, and was trained on a large corpus of Russian dialogue data, making it well suited for building generative conversational agents.

Architecture

The ruDialoGPT-medium model uses the GPT-2 architecture to generate coherent, contextually relevant conversational responses. It operates with a context size of 3, meaning it conditions on the last three conversational turns when generating a response; the special tokens @@ПЕРВЫЙ@@ and @@ВТОРОЙ@@ mark the alternating speakers, as shown in the sketch below.
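
In practice a dialogue is flattened into a single string using these speaker tokens, exactly as in the generation example later on this page. A minimal helper for building such a context might look like this (the function name and structure are illustrative, not part of the model's API):

    def build_context(turns):
        """Flatten alternating dialogue turns into the model's input format."""
        speakers = ['@@ПЕРВЫЙ@@', '@@ВТОРОЙ@@']
        parts = [f'{speakers[i % 2]} {turn}' for i, turn in enumerate(turns)]
        # The trailing speaker token prompts the model to answer as that speaker.
        parts.append(speakers[len(turns) % 2])
        return ' '.join(parts)

    # 'привет' = 'hi', 'как дела?' = 'how are you?'
    print(build_context(['привет', 'привет', 'как дела?']))
    # @@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@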

Evaluation

The model was evaluated on a private validation set using the human-evaluation metrics from arXiv:2001.09977 ("Towards a Human-like Open-Domain Chatbot"). These metrics include:

  • Sensibleness: Whether the model's response makes sense given the context.
  • Specificity: Whether the response is specific to the context, avoiding generic responses.
  • SSA (Sensibleness and Specificity Average): The mean of the sensibleness and specificity scores.

Performance comparisons with other models:

  • ruDialoGPT-small: Sensibleness 0.64, Specificity 0.50, SSA 0.57
  • ruDialoGPT-medium: Sensibleness 0.78, Specificity 0.69, SSA 0.735
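
As a quick sanity check on the figures above, SSA is simply the arithmetic mean of the two component scores:

    # SSA = (sensibleness + specificity) / 2, per arXiv:2001.09977
    for name, sens, spec in [('ruDialoGPT-small', 0.64, 0.50),
                             ('ruDialoGPT-medium', 0.78, 0.69)]:
        print(name, round((sens + spec) / 2, 3))
    # ruDialoGPT-small 0.57
    # ruDialoGPT-medium 0.735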

Guide: Running Locally

To use the ruDialoGPT-medium model locally, follow these steps:

  1. Install Dependencies: Ensure you have PyTorch and transformers installed.

    pip install torch transformers
    
  2. Import Libraries (AutoModelForCausalLM supersedes the deprecated AutoModelWithLMHead):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    
  3. Load the Model and Tokenizer:

    tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
    model = AutoModelForCausalLM.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
    
  4. Prepare Input and Generate Output:

    # The context alternates speaker tokens; the trailing @@ВТОРОЙ@@ prompts
    # the model to reply as the second speaker.
    # ('привет' = 'hi', 'как дела?' = 'how are you?')
    inputs = tokenizer('@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@', return_tensors='pt')
    generated_token_ids = model.generate(
        **inputs,
        top_k=10,                  # sample from the 10 most likely next tokens
        top_p=0.95,                # nucleus-sampling probability mass
        num_beams=3,               # beam search over 3 beams
        num_return_sequences=3,    # return 3 candidate continuations
        do_sample=True,
        no_repeat_ngram_size=2,    # forbid repeating any bigram
        temperature=1.2,
        repetition_penalty=1.2,
        length_penalty=1.0,
        eos_token_id=50257,        # stop-token id from the model card example
        max_new_tokens=40
    )
    context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
    
  5. Output: The list context_with_response holds three decoded strings, each containing the original context followed by a generated reply (a sketch for trimming out the reply alone follows below).
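
Each decoded string still begins with the prompt, so the reply itself has to be cut out. One possible post-processing sketch (the helper, and the assumption that decoding reproduces the context verbatim, are illustrative rather than part of the model card):

    def extract_reply(decoded, context):
        """Drop the supplied context, then keep only the first generated turn."""
        reply = decoded[len(context):]
        for token in ('@@ПЕРВЫЙ@@', '@@ВТОРОЙ@@'):
            reply = reply.split(token)[0]
        return reply.strip()

    context = '@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@'
    replies = [extract_reply(s, context) for s in context_with_response]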

For optimal performance, consider using cloud GPU services like AWS, Google Cloud, or Azure to handle computational requirements efficiently.
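
If a CUDA GPU is available, moving the model and the tokenized inputs onto it speeds up generation considerably. A minimal sketch, assuming the objects from the steps above:

    # Run generation on a GPU when one is available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    inputs = inputs.to(device)   # BatchEncoding supports .to(device)
    generated_token_ids = model.generate(**inputs, max_new_tokens=40)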

License

The ruDialoGPT-medium model is licensed under the MIT License, allowing for extensive use and modification with minimal restrictions.
