personaGPT
af1tang
Introduction
PersonaGPT is an open-domain conversational agent that generates personalized responses and can pursue turn-level goals specified through "action codes." It is built on DialoGPT-medium, which uses the GPT-2 architecture, and is fine-tuned on the Persona-Chat dataset. Special tokens mark which parts of the input are persona facts and which are conversation history.
Architecture
PersonaGPT keeps the GPT-2 architecture of DialoGPT-medium, the model it is fine-tuned from, and adds special tokens that distinguish the conversational history from the persona profile. It is further trained with active learning to perform controlled decoding toward turn-level goals.
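To make the input layout concrete, the sketch below assembles the persona-conditioned prompt as a plain string, following the token order used in the Running Locally steps: `<|p2|>`, the persona facts (each terminated by the end-of-sequence token), `<|sep|>`, `<|start|>`, and then the dialog turns, each also ending in the end-of-sequence token. The function name `build_persona_prompt` is hypothetical, and `<|endoftext|>` is assumed to be the end-of-sequence token, as in the stock GPT-2 tokenizer. The guide itself encodes the persona block and each turn separately and concatenates the token IDs, so BPE boundaries may differ slightly from encoding this string in a single call.

```python
# Hypothetical helper: builds the persona-conditioned prompt as a string.
# Assumes "<|endoftext|>" is the end-of-sequence token (GPT-2's default).
EOS = "<|endoftext|>"


def build_persona_prompt(facts, turns, eos=EOS):
    """Return <|p2|> + facts + <|sep|><|start|> + turns, each fact/turn ending in eos."""
    persona_block = "<|p2|>" + "".join(fact + eos for fact in facts)
    history_block = "".join(turn + eos for turn in turns)
    return persona_block + "<|sep|>" + "<|start|>" + history_block


prompt = build_persona_prompt(
    facts=["i have two dogs.", "i love hiking."],
    turns=["hi! what do you do for fun?"],
)
print(prompt)
```

Keeping the persona block first and the growing history last means the persona prefix stays fixed while only the tail changes from turn to turn.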
Training
The model is fine-tuned on the Persona-Chat dataset, which includes personality facts and dyadic conversation data. Special tokens and action codes were introduced to enhance the model's ability to generate personalized and context-aware responses.
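Action codes use a parallel layout. In the guide's controlled-generation step, the prefix is `<|act|> `, the action text, `<|p1|>`, an empty slot where first-speaker persona facts would go, `<|sep|>`, and `<|start|>`, followed by the dialog history. The sketch below reproduces that string; `build_action_prompt` is a hypothetical name, and `<|endoftext|>` is assumed to terminate each dialog turn.

```python
# Hypothetical helper mirroring the action-prefix format from the guide's
# controlled-generation step. Assumes "<|endoftext|>" ends each dialog turn.
EOS = "<|endoftext|>"


def build_action_prompt(action, turns, eos=EOS):
    """Return '<|act|> {action}<|p1|><|sep|><|start|>' followed by the turns."""
    # The slot between <|p1|> and <|sep|> is left empty, as in the guide.
    prefix = "<|act|> " + action + "<|p1|>" + "<|sep|>" + "<|start|>"
    return prefix + "".join(turn + eos for turn in turns)


print(build_action_prompt("ask about pets.", []))
# <|act|> ask about pets.<|p1|><|sep|><|start|>
```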
Guide: Running Locally
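- Load the tokenizer and model. The later steps also call the helper functions `flatten`, `to_var`, `generate_next`, and `display_dialog_history`, which this snippet does not define; define them before running those steps: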
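```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Download the PersonaGPT tokenizer and weights from the Hugging Face Hub
tokenizer = GPT2Tokenizer.from_pretrained("af1tang/personaGPT")
model = GPT2LMHeadModel.from_pretrained("af1tang/personaGPT")

# Move the model to the GPU when one is available
if torch.cuda.is_available():
    model = model.cuda()
```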
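The remaining steps call four helpers: `flatten`, `to_var`, `generate_next`, and `display_dialog_history`. The following is a minimal sketch of them, not the authors' code. It assumes the `model` and `tokenizer` objects from the loading snippet are in scope, and the sampling settings (`top_k=10`, `top_p=0.92`, `max_length=1000`) are illustrative assumptions. `generate_next` returns only the newly generated token IDs as a Python list, so the result can be appended to `dialog_hx` and concatenated with other token-ID lists directly.

```python
# Minimal sketch of the helpers used by the later steps (not the authors' code).
# Assumes `model` and `tokenizer` from the loading snippet are already defined.


def flatten(list_of_lists):
    """Concatenate a list of token-ID lists into one flat list."""
    return [token for sublist in list_of_lists for token in sublist]


def to_var(token_ids):
    """Wrap nested token-ID lists in a tensor, on the GPU when one is available."""
    import torch  # already imported by the loading snippet; repeated so this helper stands alone

    tensor = torch.tensor(token_ids)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor


def generate_next(bot_input_ids, top_k=10, top_p=0.92, max_length=1000):
    """Sample a continuation and return only the new token IDs as a list."""
    output = model.generate(
        bot_input_ids,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens; keep only what the model just generated.
    return output[0, bot_input_ids.shape[-1]:].tolist()


def display_dialog_history(dialog_hx):
    """Print the history, alternating between the user (or action) side and the bot."""
    for i, turn in enumerate(dialog_hx):
        text = tokenizer.decode(turn, skip_special_tokens=True)
        speaker = ">> User" if i % 2 == 0 else "Bot"
        print(f"{speaker}: {text}")
```

Returning plain lists rather than tensors keeps the history easy to concatenate with `+`, which is how the guide builds each prompt.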
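- Set up the chatbot's persona by entering three facts:
```python
# Collect three persona facts; each fact ends with the EOS token
personas = []
for i in range(3):
    response = input(">> Fact %d: " % (i + 1)) + tokenizer.eos_token
    personas.append(response)

# Persona block: <|p2|> facts... <|sep|> <|start|>, encoded to token IDs
personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
```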
- Interact with the model for personalized dialog generation:
```python
dialog_hx = []
for step in range(8):
    # Encode the user's turn, terminated by the EOS token
    user_inp = tokenizer.encode(input(">> User: ") + tokenizer.eos_token)
    dialog_hx.append(user_inp)

    # Bot input = persona block followed by the full dialog history
    bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)
    print("Bot: {}".format(tokenizer.decode(msg, skip_special_tokens=True)))
```
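For scripted testing without `input()`, the loop above can be written as a function that takes the user messages as a list. In this sketch, `run_scripted_chat` is hypothetical, and `encode` and `generate` are injected stand-ins for the tokenizer and model; passing them in keeps the turn-by-turn prompt construction visible and runnable without downloading the model. As in the guide, each bot input is the persona IDs followed by the entire history so far.

```python
# Hypothetical non-interactive version of the chat loop. `encode` maps text to
# token IDs and `generate` maps a prompt (list of IDs) to new token IDs.


def run_scripted_chat(persona_ids, user_messages, encode, generate, eos="<|endoftext|>"):
    """Return the dialog history (list of token-ID lists), alternating user/bot."""
    dialog_hx = []
    for text in user_messages:
        dialog_hx.append(encode(text + eos))  # each user turn ends with EOS
        # The prompt is rebuilt every turn: persona block + every turn so far.
        prompt = persona_ids + [tok for turn in dialog_hx for tok in turn]
        dialog_hx.append(generate(prompt))
    return dialog_hx


# Toy stand-ins so the example runs without the model.
def toy_encode(text):
    return [len(text)]


def toy_generate(prompt):
    return [sum(prompt)]


history = run_scripted_chat([1, 2], ["hi", "how are you?"], toy_encode, toy_generate)
print(history)
```

With the real model, `encode` could be `tokenizer.encode` and `generate` something like `lambda ids: generate_next(to_var([ids]).long())`, using the helpers the guide relies on.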
- Use controlled response generation by selecting actions. At each step, choose an action code; the model first generates a turn that pursues that action, and the persona-conditioned bot then replies:
```python
action_space = [
    'ask about kids.', 'ask about pets.', 'talk about work.',
    'ask about marital status.', 'talk about travel.', 'ask about age and gender.',
    'ask about hobbies.', 'ask about favorite food.', 'talk about movies.',
    'talk about music.', 'talk about politics.',
]

dialog_hx = []
for step in range(8):
    # Keep prompting until a valid action index (0-10) is entered
    act = None
    while act not in action_space:
        display_dialog_history(dialog_hx)
        print(" actions: ")
        for k, v in enumerate(action_space):
            print(k, v)
        try:
            act = action_space[int(input(" input [0-10]: "))]
        except (ValueError, IndexError):
            act = None

    # Action prefix: <|act|> action <|p1|> (no p1 facts) <|sep|> <|start|>
    action_prefix = tokenizer.encode(''.join(['<|act|> ', act, '<|p1|>', '<|sep|>', '<|start|>']))

    # First pass: generate a turn that pursues the chosen action
    bot_input_ids = to_var([action_prefix + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)

    # Second pass: the persona-conditioned bot replies
    bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)

display_dialog_history(dialog_hx)
```
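The controlled step can likewise be run without `input()` by passing the chosen action directly. In this hypothetical `controlled_turn` function, the first generation is conditioned on the action prefix (with no persona facts) and the second on the persona block, mirroring the two `generate_next` calls in the loop above. `encode` and `generate` are injected stand-ins for the tokenizer and model, `ACTION_SPACE` is an illustrative subset of the guide's action list, and rejecting unknown actions is an added safeguard rather than documented model behavior.

```python
# Hypothetical non-interactive version of one action-controlled turn.
# `encode`: text -> token IDs; `generate`: prompt IDs -> new token IDs.

ACTION_SPACE = ["ask about kids.", "ask about pets.", "talk about work."]  # illustrative subset


def controlled_turn(action, dialog_hx, persona_ids, encode, generate,
                    action_space=ACTION_SPACE):
    """Append an action-driven turn and the persona bot's reply to dialog_hx."""
    if action not in action_space:
        raise ValueError(f"unknown action: {action!r}")
    history = [tok for turn in dialog_hx for tok in turn]
    # Pass 1: generate a turn that pursues the chosen action (empty <|p1|> slot).
    action_prefix = encode("<|act|> " + action + "<|p1|>" + "<|sep|>" + "<|start|>")
    dialog_hx.append(generate(action_prefix + history))
    # Pass 2: the persona-conditioned bot replies to the updated history.
    history = [tok for turn in dialog_hx for tok in turn]
    dialog_hx.append(generate(persona_ids + history))
    return dialog_hx


# Toy stand-ins so the example runs without the model.
def toy_encode(text):
    return [len(text)]


def toy_generate(prompt):
    return [sum(prompt)]


hx = controlled_turn("ask about pets.", [], [1, 2], toy_encode, toy_generate)
print(hx)
```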
Generation runs much faster on a GPU; if none is available locally, cloud GPUs such as Google Colab or AWS EC2 are an option.
License
PersonaGPT is released under the GPL-3.0 license, which permits use, modification, and redistribution, provided that distributed derivative works are released under the same license.