rubert_ext_sum_gazeta
IlyaGusev

Introduction
The rubert_ext_sum_gazeta model is designed for extractive summarization and is tailored to articles from Gazeta.ru. It is built on RuBERT, a Russian version of BERT, and casts sentence selection as a token classification task: sentence-separator tokens are scored, and the highest-scoring sentences form the summary.
Architecture
The model is based on the rubert-base-cased architecture, a BERT model adapted for the Russian language. It uses the transformers library and is implemented in PyTorch.
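
As a quick sanity check, the checkpoint's configuration can be inspected through the standard transformers config API before downloading the full weights. This is a minimal sketch; the printed values are simply whatever the published checkpoint reports.

    from transformers import AutoConfig

    # Inspect the checkpoint's configuration without loading the weights
    config = AutoConfig.from_pretrained("IlyaGusev/rubert_ext_sum_gazeta")
    print(config.model_type)                # BERT-style encoder
    print(config.num_labels)                # classes in the per-token classification head
    print(config.max_position_embeddings)   # maximum sequence length the position embeddings support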
Training
The model was trained on the Gazeta dataset, a collection of articles from Gazeta.ru. Specific details of the training procedure and hyperparameters have not been published. The model is optimized for extractive summarization, where key sentences are selected directly from the input text rather than generated.
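
Since the exact training recipe is not documented, the sketch below is purely an illustration of what extractive-summarization training data usually looks like: sentence-level 0/1 labels built by greedily matching article sentences against a reference summary. The greedy_oracle helper and its word-overlap score are illustrative assumptions, not code from this model's training pipeline.

    def greedy_oracle(sentences, reference, max_sentences=3):
        """Illustrative only: greedily pick sentences that best cover a reference summary."""
        ref_words = set(reference.lower().split())

        def coverage(indices):
            words = set(" ".join(sentences[j] for j in indices).lower().split())
            return len(words & ref_words) / max(len(ref_words), 1)

        selected = []
        for _ in range(max_sentences):
            best_idx, best_cov = None, coverage(selected)
            for i in range(len(sentences)):
                if i in selected:
                    continue
                cov = coverage(selected + [i])
                if cov > best_cov:
                    best_idx, best_cov = i, cov
            if best_idx is None:
                break
            selected.append(best_idx)

        # Sentence-level labels: 1 if the sentence belongs to the oracle summary, else 0
        return [1 if i in selected else 0 for i in range(len(sentences))]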
Guide: Running Locally
Here's how to run the model locally:
- Install Required Libraries:
  Ensure you have Python and PyTorch installed, then install the transformers and razdel libraries:

    pip install transformers razdel
- Load the Model:

    import razdel
    import torch  # used for inference in the later steps
    from transformers import AutoTokenizer, BertForTokenClassification

    model_name = "IlyaGusev/rubert_ext_sum_gazeta"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(model_name)
- Prepare Input Data:
  Split your article into sentences using razdel and join them with the separator token.

    article_text = "..."
    sentences = [s.text for s in razdel.sentenize(article_text)]
    sep_token = tokenizer.sep_token
    article_text = sep_token.join(sentences)
- Tokenize and Infer:
  Tokenize the input and run inference to get the logits used for sentence selection.

    inputs = tokenizer(
        [article_text],
        max_length=500,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Positions of the separator tokens, one per sentence
    sep_token_id = tokenizer.sep_token_id
    sep_mask = inputs["input_ids"][0] == sep_token_id
    with torch.no_grad():
        outputs = model(**inputs)
    # Class-1 logits: how strongly each token position votes for inclusion
    logits = outputs.logits[0, :, 1]
- Select and Print Summary:
  Select the top sentences based on the logits and print the summary.

    # Keep only the logits at separator positions (one score per sentence)
    logits = logits[sep_mask]
    # Take the three highest-scoring sentences and restore document order
    logits, indices = logits.sort(descending=True)
    indices = sorted(indices.cpu().tolist()[:3])
    summary = " ".join([sentences[idx] for idx in indices])
    print(summary)
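
Putting the steps above together, one way to package the whole pipeline is a single helper function. The summarize name, its top_k and max_length defaults, and the argsort-based selection are simply a wrapper around the steps shown above, not part of the original model card.

    import razdel
    import torch
    from transformers import AutoTokenizer, BertForTokenClassification

    model_name = "IlyaGusev/rubert_ext_sum_gazeta"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(model_name)
    model.eval()

    def summarize(article_text, top_k=3, max_length=500):
        # Sentence-split with razdel and join with the separator token, as above
        sentences = [s.text for s in razdel.sentenize(article_text)]
        text = tokenizer.sep_token.join(sentences)

        inputs = tokenizer(
            [text],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        sep_mask = inputs["input_ids"][0] == tokenizer.sep_token_id

        with torch.no_grad():
            logits = model(**inputs).logits[0, :, 1]

        # Score each sentence by its separator-token logit, keep top_k in document order
        sent_logits = logits[sep_mask]
        top = sent_logits.argsort(descending=True)[:top_k]
        indices = sorted(top.cpu().tolist())
        return " ".join(sentences[i] for i in indices)

    # Example usage (replace "..." with a real Gazeta.ru article)
    # print(summarize("..."))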
Cloud GPUs
For improved performance, especially with large datasets or articles, consider using cloud GPUs from providers like AWS, Google Cloud, or Azure.
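
Continuing from the guide above, running on a GPU only requires moving the model and the tokenized inputs to the device. This is a minimal sketch using standard PyTorch device handling, not anything specific to this model.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Move the tokenized inputs to the same device before inference
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits[0, :, 1]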
License
The model is licensed under the Apache 2.0 License, allowing for both personal and commercial use with proper attribution.