ruDialoGPT-medium
t-bank-ai
Introduction
ruDialoGPT-medium is a Russian-language dialogue generation model developed by T-Bank AI (t-bank-ai). It is built on the sberbank-ai/rugpt3medium model, which uses the GPT-2 architecture, and was trained on a large corpus of dialogue data. This makes it well suited for building generative conversational agents.
Architecture
ruDialoGPT-medium is a decoder-only GPT-2 model that generates a reply conditioned on the preceding dialogue. It uses a context size of 3, meaning it takes the last three conversational turns into account when generating a response.
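Because the model only takes the last three turns into account, a longer dialogue history has to be trimmed before it is tokenized. The sketch below assumes the speaker-tag prompt format used in the local-run guide (`@@ПЕРВЫЙ@@` for the first speaker, `@@ВТОРОЙ@@` for the second). The helper `build_prompt` is hypothetical, and labelling the turns so that the reply to be generated always gets `@@ВТОРОЙ@@` is an assumption based on that single example prompt.

```python
# Speaker tags taken from the example prompt in the guide.
FIRST = "@@ПЕРВЫЙ@@"
SECOND = "@@ВТОРОЙ@@"


def build_prompt(turns, context_size=3):
    """Hypothetical helper: build a ruDialoGPT-style prompt from utterances.

    Keeps only the last `context_size` turns. Assumption: tags are assigned
    counting back from the end, so the reply to be generated gets SECOND.
    """
    recent = turns[-context_size:] if context_size > 0 else []
    n = len(recent)
    parts = []
    for i, text in enumerate(recent):
        # The last turn is FIRST, the one before it SECOND, and so on.
        tag = FIRST if (n - 1 - i) % 2 == 0 else SECOND
        parts.append(f"{tag} {text}")
    parts.append(SECOND)  # the model continues the dialogue after this tag
    return " ".join(parts)
```

For example, `build_prompt(['привет', 'привет', 'как дела?'])` reproduces the prompt string used in the local-run guide, and any turns older than the last three are dropped.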
Evaluation
The model was evaluated on a private validation set using the metrics defined in arXiv:2001.09977 (Adiwardana et al., "Towards a Human-like Open-Domain Chatbot"):
- Sensibleness: Whether the model's response makes sense given the context.
- Specificity: Whether the response is specific to the context, avoiding generic responses.
- SSA (Sensibleness and Specificity Average): The mean of the sensibleness and specificity scores.
Results for ruDialoGPT-medium compared with the smaller ruDialoGPT-small:
- ruDialoGPT-small: Sensibleness 0.64, Specificity 0.5, SSA 0.57
- ruDialoGPT-medium: Sensibleness 0.78, Specificity 0.69, SSA 0.735
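Since SSA is the plain mean of sensibleness and specificity, the reported values can be checked directly. A minimal Python sketch; the `ssa` function is illustrative and not part of any library:

```python
import math


def ssa(sensibleness: float, specificity: float) -> float:
    """Sensibleness and Specificity Average: the mean of the two scores."""
    return (sensibleness + specificity) / 2


# Scores as reported above: (sensibleness, specificity, reported SSA).
reported = {
    "ruDialoGPT-small": (0.64, 0.5, 0.57),
    "ruDialoGPT-medium": (0.78, 0.69, 0.735),
}

for name, (sens, spec, reported_ssa) in reported.items():
    computed = ssa(sens, spec)
    # isclose avoids spurious mismatches caused by floating-point rounding
    assert math.isclose(computed, reported_ssa), name
    print(f"{name}: SSA = {computed:.3f}")
```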
Guide: Running Locally
To use the ruDialoGPT-medium model locally, follow these steps:
- Install Dependencies: Ensure you have PyTorch and `transformers` installed:
```bash
pip install torch transformers
```
- Import Libraries:
```python
import torch
# AutoModelForCausalLM replaces AutoModelWithLMHead, which is deprecated in transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
```
- Load the Model and Tokenizer:
```python
tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
model = AutoModelForCausalLM.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
```
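- Prepare Input and Generate Output:
```python
# Turns are tagged @@ПЕРВЫЙ@@ ("first speaker") and @@ВТОРОЙ@@ ("second speaker");
# "привет" = "hi", "как дела?" = "how are you?". The trailing @@ВТОРОЙ@@
# asks the model to write the second speaker's reply.
inputs = tokenizer('@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@', return_tensors='pt')
generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,               # with do_sample=True this is beam-sample decoding
    num_return_sequences=3,    # return three candidate continuations
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=1.2,
    repetition_penalty=1.2,
    length_penalty=1.0,
    eos_token_id=50257,        # generation stops when this token id is produced
    max_new_tokens=40,
)
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
```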
- Output: `context_with_response` is a list of three decoded strings (one per returned sequence), each containing the input context followed by the generated reply.
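Because each decoded string repeats the whole context, it is often convenient to keep only the new turn. The sketch below is a hypothetical post-processing helper, not part of `transformers`; it assumes the decoded text reproduces the speaker tags of the prompt verbatim, so the reply is the segment after the prompt's last tag and before the next tag (if the model started another turn).

```python
import re

# Speaker tags used in the prompt format of the example.
TAG_PATTERN = re.compile(r"@@(?:ПЕРВЫЙ|ВТОРОЙ)@@")


def extract_reply(decoded: str, prompt: str) -> str:
    """Hypothetical helper: return only the newly generated turn.

    Assumes `decoded` starts with the prompt's speaker tags unchanged.
    """
    n_prompt_tags = len(TAG_PATTERN.findall(prompt))
    # segments[0] precedes the first tag; segments[i] follows the i-th tag
    segments = TAG_PATTERN.split(decoded)
    if len(segments) <= n_prompt_tags:
        return ""  # the decoded text does not contain the full prompt
    return segments[n_prompt_tags].strip()
```

Usage, with `prompt` being the string passed to the tokenizer: `replies = [extract_reply(text, prompt) for text in context_with_response]`.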
The example above runs on the CPU; generation is considerably faster on a GPU. If no local GPU is available, cloud GPU instances (for example on AWS, Google Cloud, or Azure) are an option.
License
The ruDialoGPT-medium model is licensed under the MIT License, allowing for extensive use and modification with minimal restrictions.