Fine-tuning GPT-2 with aitextgen

Preparation

python -m pip install --user aitextgen
In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # tokenizer optimization, only applies to Jupyter notebooks
import random

import torch

from aitextgen import aitextgen
from transformers import GPT2TokenizerFast
from aitextgen.TokenDataset import TokenDataset

import sys
sys.stderr = open(os.devnull, "w") # suppress annoying messages in Jupyter notebooks
sys.stderr.flush()

import multiprocessing
cpu_cores = multiprocessing.cpu_count()

Training options

Training options for fine-tuning a pretrained GPT-2 model.

In [3]:
minibatch_size = 1       # how many samples to train on in one training step (if you run out of memory, lower this value)
total_batch_size = 1000  # how many samples to take before taking an optimizer step
warmup_steps = 5000      # eases the model into training to stabilize the gradient

# gradient accumulation simulates a larger effective batch size by accumulating gradients
# over several minibatches before each optimizer step, which stabilizes the gradient
gradient_accumulation_steps = total_batch_size // minibatch_size

learning_rate = 5e-5     # note: you'll need to train on a larger dataset with a learning rate of 1e-8 to get good results

# note: if you run out of memory with a minibatch_size of 1 and gradient checkpointing enabled
#       try lowering this
max_length = 128

# max_length controls how much of the model's context is trained on
# GPT-2's context size is 1024 tokens, so setting this to 256 trains on only 256 tokens at a time
# the memory required by the attention layers grows as O(max_length**2)

# output directory to save the trained model
output_dir = "dorothy-gpt2"

# gpt2 pretrained model name
model_name = 'gpt2' # gpt2 or gpt2-medium

# number of threads loading data
num_workers = min(minibatch_size, cpu_cores)

Concatenating training data

To use multiple training files with aitextgen, they need to be concatenated into a single file.

In [3]:
# text file to use for training
training_filenames = ["anime-transcripts.valid.txt", "dorothy.txt"]
# note: we're using the anime-transcripts validation set for training since it's small
#       and because aitextgen doesn't support checking validation loss

training_filename = "gpt2_training.txt"
with open(training_filename, "w") as f:
    for filename in training_filenames:
        with open(filename) as g:
            f.write(g.read())
    
assert os.stat(training_filename).st_size != 0, f"Training file {training_filename} is empty."

Creating a pretrained tokenizer

A tokenizer prepares inputs for a model by converting text into tokens, each with a unique ID number.

In [4]:
# create a tokenizer from a pretrained GPT2 tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

Using a tokenizer

Tokenizers can operate on either a string or a list of strings, and can return either a list of token IDs or a tensor.

When encoding input for a model, the text must be converted into a tensor with a batch dimension. You can do this by specifying return_tensors="pt".
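
The cell below only encodes a single string; as a minimal sketch of the list-of-strings case (GPT-2 has no padding token by default, so this assumes reusing its end-of-text token for padding):

tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse <|endoftext|> for padding
batch = tokenizer(["Dorothy: Hey, honey!", "Jill: Oh! Hey, Dorothy."],
                  padding=True, return_tensors="pt")
print(batch["input_ids"].shape)   # [2, length of the longest sequence]
print(batch["attention_mask"])    # 1 marks real tokens, 0 marks padding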

In [7]:
# the tokenizer turns text into byte-pair encoded tokens (GPT-2 uses byte-level BPE)
encoded = tokenizer.encode("Dorothy: Hey, honey!")
decoded = tokenizer.decode(encoded)
print("Encoded tokens:", encoded)
print("   ID  Text")
for token in encoded:
    print(f"{token:>5}: {tokenizer.decode([token])}")
print("Decoded tokens:", decoded)

model_input = tokenizer.encode("Dorothy: Hey, honey!", return_tensors="pt")
print("---")
print("model_input:", model_input)
print("model_input.shape:", model_input.shape)
Encoded tokens: [35, 273, 14863, 25, 14690, 11, 12498, 0]
   ID  Text
   35: D
  273: or
14863: othy
   25: :
14690:  Hey
   11: ,
12498:  honey
    0: !
Decoded tokens: Dorothy: Hey, honey!
---
model_input: tensor([[   35,   273, 14863,    25, 14690,    11, 12498,     0]])
model_input.shape: torch.Size([1, 8])

Creating and loading models

If a fine-tuned model doesn't already exist in the output directory, create a new one from a pretrained GPT-2 model.

In [6]:
# create a new model or load an existing one
if not os.path.exists(output_dir):
    # create the text generator and place the model on the GPU if CUDA is available
    print("Created new model")
    textgen = aitextgen(
        model_name,
        to_gpu=torch.has_cuda,
        gradient_checkpointing=True # if you have lots of memory disable this
    )
else:
    print("Loaded model from", output_dir)
    # load the model if it already exists
    textgen = aitextgen(
        model=os.path.join(output_dir, "pytorch_model.bin"),
        config=os.path.join(output_dir, "config.json"),
        tokenizer=tokenizer,
        to_gpu=torch.has_cuda,
        gradient_checkpointing=True
    )
    textgen.model.train()
Loaded model from dorothy-gpt2

GPU Information

In [7]:
# print nvidia GPU information, only works in Jupyter notebook
if torch.has_cuda:
    !nvidia-smi
# note: this notebook requires at least 6 GB of free GPU memory
Thu Apr  1 13:36:54 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
...
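
If nvidia-smi isn't available, free GPU memory can also be checked from Python; this is a minimal sketch and assumes a PyTorch version that provides torch.cuda.mem_get_info:

if torch.has_cuda:
    # returns (free, total) device memory in bytes
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"free: {free_bytes / 1e9:.1f} GB / total: {total_bytes / 1e9:.1f} GB")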

Creating a dataset

Datasets used in training are Python iterables; each training step, samples are drawn from them and collated into a minibatch.

In [8]:
# you can build datasets for training by creating TokenDatasets
dataset = TokenDataset(
    training_filename,
    tokenizer=tokenizer,    
    block_size=max_length,
    num_workers=num_workers,
    pin_memory=torch.has_cuda
)
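
A TokenDataset behaves like a standard PyTorch dataset, so it can be inspected directly; a minimal sketch (assuming each item is a block of block_size token IDs, which is how aitextgen builds them):

print(len(dataset))               # number of blocks in the dataset
sample = dataset[0]               # one training sample: a block of block_size token IDs
print(tokenizer.decode(sample))   # decode the block back into text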

Training

Warning: aitextgen does not support validation sets to check that the model isn't overfitting.

With num_steps = 16, training takes roughly 15 minutes and passes over the dataset about 3 times.
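
Since aitextgen has no built-in validation pass, a rough check can be done by hand against the underlying transformers model. The sketch below assumes a held-out text file (the filename is only a placeholder) and that TokenDataset items can be batched by a regular PyTorch DataLoader:

from torch.utils.data import DataLoader

# placeholder filename for a held-out split; substitute your own validation file
valid_dataset = TokenDataset("validation.txt", tokenizer=tokenizer, block_size=max_length)
valid_loader = DataLoader(valid_dataset, batch_size=minibatch_size)

textgen.model.eval()
losses = []
with torch.no_grad():
    for batch in valid_loader:
        batch = batch.to(textgen.model.device)
        output = textgen.model(batch, labels=batch)  # labels=input_ids gives the language modeling loss
        losses.append(output.loss.item())
print("validation loss:", sum(losses) / len(losses))
textgen.model.train()  # switch back to training mode before fine-tuning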

In [4]:
# training settings
generate_every = gradient_accumulation_steps * 1  # generate a sample every optimizer step
save_every = gradient_accumulation_steps * 8      # save every 8 optimizer steps
num_steps = 16                                    # finish after 16 optimizer steps
weight_decay = 0 # aitextgen enables weight decay by default, but it can cause catastrophic forgetting during fine-tuning, so disable it

To calculate the perplexity n during training: n = exp(loss)

Perplexity describes how confused the model is when predicting a sample: it is as though the model had to choose uniformly and independently among n possibilities for each token.

Fine-tuning below a loss of 2.5 is not recommended; it is an indication that the model is overfitting.
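
As a minimal worked example (the loss value here is made up), converting a logged training loss into perplexity:

import math

loss = 3.2                   # example loss taken from a training log
perplexity = math.exp(loss)  # ~24.5: roughly 24-25 equally likely choices per token
print(perplexity)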

In [11]:
# free unused memory
import gc
gc.collect()
if torch.has_cuda:
    torch.cuda.empty_cache() 

# increase dropout to prevent overfitting
dropout = 0.25
textgen.model.config.resid_pdrop = dropout
textgen.model.config.attn_pdrop = dropout
textgen.model.config.embd_pdrop = dropout
# the config values are only read when the model is built, so also update
# the existing nn.Dropout modules in place
for module in textgen.model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = dropout

# train the model, it will save the model periodically and after completion
textgen.train(
    dataset,
    output_dir=output_dir,
    weight_decay=weight_decay, # weight decay causes catastrophic forgetting during fine-tuning
    batch_size=minibatch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    generate_every=generate_every,
    save_every=save_every,
    num_steps=num_steps,
    num_workers=num_workers
)
# note: aitextgen's training progress bar is bugged with gradient accumulation
2,000 steps reached: generating sample texts.
==========
Kaname: You're so powerful!
Kaname: You're the only person in the world! This is the only time you're actually able to run the tournament!
Kaname: Your power has gone up!
Kaname: You can't get out of your way!
Kaname: You're a loser!
Kaname: You've been in a match! Kaname!
Kaname: You're going to be our best team!
Kaname: You're not even the most skilled! It's going to be a big tournament!
Kaname: It's going to be a tournament! You'll get the title!
Kaname: You're a winner! You're the first to win! You're a loser! You're a contest! You's going to get the head away! This is a tournament!
Kaname: It's gonna be a fight! You've got a good chance to win!
Kaname: The next match!
Kaname: I have to win!
Kaname: I can't see the winner!
Kaname: We're going to be the last match!
Kaname:
==========
4,000 steps reached: generating sample texts.
...

Generating text with aitextgen

Generating text with aitextgen is very simple.

In [23]:
textgen.model.eval() # always call this before sampling, otherwise the model will run inference with dropout enabled

# bug workaround: after training, the model may be left on the CPU, so move it back to the GPU
if torch.has_cuda:
    textgen.model = textgen.model.cuda()
    
textgen.generate(3, prompt="Dorothy: Hey, honey!\nJill: Oh! Hey, Dorothy.\nDorothy: Can I", max_length=max_length)
Dorothy: Hey, honey!
Jill: Oh! Hey, Dorothy.
Dorothy: Can I get something to drink?
Jill: Sure, I'm drinking.
Dorothy: Have you ever had a drink in the past?
Jill: Sure.
Dorothy: So, how long has it been since you've been in here?
Jill: About a week.
Dorothy: Oh, that's right.
Jill: You've been so busy the last couple of days, I've been feeling really sleepy.
Dorothy: Huh?
==========
Dorothy: Hey, honey!
Jill: Oh! Hey, Dorothy.
Dorothy: Can I ask you one question?
Jill: No, not a question.
Dorothy: What for?
Jill: Um... "What for?"
Dorothy: "Are you a doctor?"
Jill: And I'm a little curious.
Dorothy: "Do you have a job"? Or is that what some people call it?
Dorothy: I don't know.
Dorothy: "Do you have any hobbies?"
Dorothy:
==========
Dorothy: Hey, honey!
Jill: Oh! Hey, Dorothy.
Dorothy: Can I look at you later maybe?
Jill: Sure. I'll be back.
Jill: Sounds good to me. Can you give me a call back later?
Dorothy: Sure.
Jill: Who should I call?
Dorothy: I'm sorry, but I need to talk to you about something.
Dorothy: I need to talk to you.
Dorothy: Hey, Dorothy.
Dorothy: I wanted to talk to you about

Generating text with more control

aitextgen doesn't provide a way to change the top-k and top-p sampling, so we have to sample the transformers model directly.

The generate function will return tokens instead of text, which we'll be using in the next section.

In [26]:
# we will write our own generate function that returns tokens
def generate(model, input_ids, top_k=32, top_p=0.9, temperature=0.9, n=3, max_length=128):
    model.eval()
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=n,
    )

prompts = [
    "Dorothy: Hey, honey!\nAnon: Tell me something interesting, Dorothy.\nDorothy:",
    "Haruhi: We're gonna play baseball today!\nKirito: Sorry, but I'm not cool with that plan.\nHaruhi:",
    "Dorothy: How do you know what's real?\nMorpheus: If real is what you can feel, smell, taste and see, then real is simply electrical signals interpreted by your brain.\nDorothy:"
]

# pretty colors
C_BOLD = "\033[1m"
C_ENDC = "\033[0m"

for prompt in prompts:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(textgen.model.device)
    out = generate(textgen.model, input_ids, max_length=max_length, n=1)
    for encoded in out:
        print(C_BOLD + prompt + C_ENDC + tokenizer.decode(encoded[input_ids.size(1):]))
        print("-----")
Dorothy: Hey, honey!
Anon: Tell me something interesting, Dorothy.
Dorothy: Do you have a favorite book?
Anon: Um... Um... Actually, I think the books I have are a lot more interesting.
Dorothy: No, it's not like that! I can't wait to read this!
Anon: That's all right. Let's see...
Jill: What was that about?
Dorothy: A letter of apology, and I... I'm so sorry!
Jill: I really don't know what to do, but
-----
Haruhi: We're gonna play baseball today!
Kirito: Sorry, but I'm not cool with that plan.
Haruhi: We'll play baseball today. You're not gonna be late. I'm gonna tell you everything I know about the game. It's gonna be a long day.
Kyon: The day before the game was supposed to begin. And it felt like it was just beginning. I was a little confused, thinking about something. I was gonna be on a bus, and I couldn't decide whether or not to stop at any point, so I just sat there on the bench
-----
Dorothy: How do you know what's real?
Morpheus: If real is what you can feel, smell, taste and see, then real is simply electrical signals interpreted by your brain.
Dorothy: If you're able to control your mind, you'll feel it, too.
Jill: No... I think that might be a bit silly.
Dorothy: But I do feel like I have a little bit of control over my own mind.
Jill: No! I just don't have control over anything, do I?
Dorothy: But what if I
-----

Interact with the model

This is a simple chatbot example to interact with the model.

Summary of the chat loop

  1. Get user input.
  2. Add it to the chat history + a new line and f'{waifu}:' to start generating from.
  3. Encode the chat history into tokens.
  4. Truncate the chat history tokens to leave room for the model to generate a response.
  5. Generate response tokens from chat history tokens.
  6. End the response at newlines and <|endoftext|> tokens.
  7. Update chat.
In [28]:
import time

response_length = 32  # maximum number of tokens the chatbot can generate

# ask for user name and waifu name
name = input("What's your name? ")
waifu = input("And who is your waifu? ")

eos = tokenizer.encode("\n")[0]
eot = tokenizer.encode("<|endoftext|>")[0]
chat_history = ""     # chat_history could be seeded with previous chat history or a prompt to improve generation

# pretty colors
C_RED = "\033[91m"
C_GREEN = "\033[92m"
C_BOLD = "\033[1m"
C_ENDC = "\033[0m"
C_USER = C_BOLD + C_GREEN
C_WAIFU = C_BOLD + C_RED

# get user input
user_input = input(f"{C_USER}{name}{C_ENDC}: ")
while user_input != "": # leaving input blank will quit   
    # add it to the chat history + a new line and f'{waifu}:' to start generating from
    chat_history += f"{name}: {user_input}\n{waifu}:"
    
    # encode the chat history into tokens
    chat_history_tokens = tokenizer.encode(chat_history, return_tensors="pt")
    
    # only some of the chat history tokens are used to leave room for the model to generate a response
    chat_history_tokens = chat_history_tokens[:, -max_length+response_length:]
    chat_history_tokens = chat_history_tokens.to(textgen.model.device)  # place them on the GPU if necessary
    
    # generate response tokens from chat history tokens
    response_tokens = generate(textgen.model, chat_history_tokens, n=1)
    response_tokens = response_tokens[:, chat_history_tokens.size(1):] # we only want the new tokens generated by GPT-2
    
    # end the response at newlines and <|endoftext|> tokens
    response_tokens_list = list(response_tokens.cpu().numpy()[0])
    end_index = len(response_tokens_list)  # default: keep the whole response
    if eos in response_tokens_list:
        end_index = min(end_index, response_tokens_list.index(eos))
    if eot in response_tokens_list:
        end_index = min(end_index, response_tokens_list.index(eot))
    
    response_text = tokenizer.decode(response_tokens[0, :end_index])
    
    # update chat
    chat_history += f"{waifu}:{response_text}\n"
    print(f"{C_WAIFU}{waifu}{C_ENDC}:{response_text}")
    time.sleep(0.25) # prevent race condition in Jupyter notebook
    
    # get user input
    user_input = input(f"{C_USER}{name}{C_ENDC}: ")
What's your name? Anon
And who is your waifu? Dorothy
Anon: Hey Dorothy!
Dorothy: Hello?
Anon: Do you like mudkips?
Dorothy: Um...
Anon: I heard you like mudkips...
Dorothy: Um...
Anon: It's okay I like mudkips too.
Dorothy: Oh, no!
Anon: 
Dorothy:...