TextAugmentation-GPT2

Fine-tuned pre-trained GPT-2 for custom, topic-specific text generation. Such a system can be used for text augmentation.


Getting Started

  1. git clone https://github.com/prakhar21/TextAugmentation-GPT2.git
  2. Move your data into the data/ directory.

* Please refer to data/SMSSpamCollection to see the expected file format.
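The authoritative reference for the format is data/SMSSpamCollection itself. As a rough sketch, assuming it follows the layout of the public UCI SMS Spam Collection (one example per line, with the label and the message separated by a single tab), a loader could look like this. The helper names parse_labeled_lines and load_labeled_corpus are hypothetical and not part of this repository.

```python
def parse_labeled_lines(lines):
    """Parse lines of the assumed form '<label><TAB><text>' into (label, text) pairs."""
    pairs = []
    for line_no, line in enumerate(lines, start=1):
        line = line.rstrip("\r\n")
        if not line.strip():
            continue  # skip blank lines
        label, sep, text = line.partition("\t")  # split on the first tab only
        if not sep:
            raise ValueError(f"line {line_no}: expected a tab between label and text")
        pairs.append((label.strip(), text.strip()))
    return pairs


def load_labeled_corpus(path):
    """Read a whole corpus file (UTF-8 assumed) using parse_labeled_lines."""
    with open(path, encoding="utf-8") as f:
        return parse_labeled_lines(f)
```

Splitting on the first tab only keeps any tabs that appear inside a message intact, and the explicit error makes a wrongly formatted file fail early instead of silently producing bad training examples.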

Tuning on Your Own Corpus

  1. Complete step 2 under Getting Started.
  2. Run python3 train.py --data_file <filename> --epoch <number_of_epochs> --warmup <warmup_steps> --model_name <model_name> --max_len <max_seq_length> --learning_rate <learning_rate> --batch <batch_size>
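train.py accepts --warmup and --learning_rate, but this README does not say how the learning rate is scheduled. A common choice when fine-tuning GPT-2 is a linear warmup followed by a linear decay to zero; the sketch below assumes that schedule (it reproduces the formula of get_linear_schedule_with_warmup in Hugging Face Transformers) and estimates the total number of optimizer steps from the dataset size, --batch and --epoch. Both function names are hypothetical.

```python
import math


def total_training_steps(num_examples, batch_size, epochs):
    """Optimizer steps for `epochs` passes, counting a final partial batch as a step."""
    return epochs * math.ceil(num_examples / batch_size)


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    # Hypothetical run: 5,000 examples, --batch 32, --epoch 3, --warmup 100.
    steps = total_training_steps(num_examples=5000, batch_size=32, epochs=3)
    for s in (0, 50, 100, steps // 2, steps):
        print(s, linear_warmup_lr(s, base_lr=5e-5, warmup_steps=100, total_steps=steps))
```

With 5,000 examples, a batch size of 32 and 3 epochs, this gives 3 × 157 = 471 steps, so --warmup should be well below that number for the decay phase to matter.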

Generating Text

  1. Run python3 generate.py --model_name <model_name> --sentences <number_of_sentences> --label <class_of_training_data>

* It is recommended that you tune these parameters for your task. If you do not pass them, the scripts fall back to their default values, which may give sub-optimal results.
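The sample results in this README suggest that each training example is serialized as the label, a colon, the message and GPT-2's <|endoftext|> token, and that --label picks the class to generate. Assuming that is how train.py and generate.py condition the model (this is not confirmed here), the prompt handling reduces to the two hypothetical helpers below: the model is fine-tuned on formatted examples and later seeded with just "LABEL:" so that it continues with a new message of that class.

```python
EOS = "<|endoftext|>"  # GPT-2's end-of-text token string


def format_training_example(label, text):
    """Hypothetical serialization of one example: 'LABEL: text<|endoftext|>'."""
    return f"{label.strip().upper()}: {text.strip()}{EOS}"


def make_prompt(label):
    """Hypothetical generation prompt: the model continues after 'LABEL:'."""
    return f"{label.strip().upper()}:"


if __name__ == "__main__":
    print(format_training_example("spam", "Free entry in 2 a wkly comp"))
    print(make_prompt("ham"))
```

Ending every training example with <|endoftext|> teaches the model where a message stops, which is presumably why each generated sample ends with that token.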

Quick Testing

I have fine-tuned the model on the SPAM/HAM dataset. You can download the fine-tuned model from here and follow the steps under the Generating Text section.

Sample Results

SPAM: You have 2 new messages. Please call 08719121161 now. £3.50. Limited time offer. Call 090516284580.<|endoftext|>
SPAM: Want to buy a car or just a drink? This week only 800p/text betta...<|endoftext|>
SPAM: FREE Call Todays top players, the No1 players and their opponents and get their opinions on www.todaysplay.co.uk Todays Top Club players are in the draw for a chance to be awarded the £1000 prize. TodaysClub.com<|endoftext|>
SPAM: you have been awarded a £2000 cash prize. call 090663644177 or call 090530663647<|endoftext|>

HAM: Do you remember me?<|endoftext|>
HAM: I don't think so. You got anything else?<|endoftext|>
HAM: Ugh I don't want to go to school.. Cuz I can't go to exam..<|endoftext|>
HAM: K.,k:)where is my laptop?<|endoftext|>
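To use generations like these for augmentation, each sample has to be turned back into a (label, text) row. A minimal sketch, assuming the "LABEL: text<|endoftext|>" shape shown above: split on the first colon, drop malformed or empty samples, and skip texts that already exist in the training data or were generated earlier in the batch. The function names are hypothetical.

```python
EOS = "<|endoftext|>"


def parse_generated(sample):
    """Split 'LABEL: text<|endoftext|>' into (label, text); None if malformed."""
    body = sample.split(EOS, 1)[0].strip()  # drop the end token and anything after it
    label, sep, text = body.partition(":")  # the first colon separates label from text
    if not sep or not label.strip() or not text.strip():
        return None
    return label.strip(), text.strip()


def build_augmented_rows(samples, existing_texts=()):
    """Keep well-formed samples whose text is new to the corpus and to this batch."""
    seen = set(existing_texts)
    rows = []
    for sample in samples:
        parsed = parse_generated(sample)
        if parsed is None or parsed[1] in seen:
            continue
        seen.add(parsed[1])
        rows.append(parsed)
    return rows
```

Splitting on the first colon only keeps messages such as "K.,k:)where is my laptop?" intact. The resulting rows can then be written out in whatever format the training file uses and appended to it.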

Important Points to Note

  • Top-k and top-p (nucleus) sampling are used to decode the sequence token by token. You can read more about them here
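For reference, here is a pure-Python sketch of the two filters applied to a single vector of next-token logits. Top-k keeps the k most probable tokens. Top-p then keeps the smallest set of those whose renormalized probability mass reaches p. The next token is drawn from what remains. This illustrates the technique and is not the code in generate.py; the function name and the defaults (top_k=0 and top_p=1.0 disable the filters) are assumptions.

```python
import math
import random


def top_k_top_p_sample(logits, top_k=0, top_p=1.0, rng=None):
    """Draw one token id from `logits` after top-k, then top-p (nucleus), filtering."""
    rng = rng if rng is not None else random.Random()

    # Softmax over the raw logits (subtracting the max keeps exp() from overflowing).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids ordered from most to least probable.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most probable ids (top_k=0 disables this filter).
    if top_k > 0:
        ranked = ranked[:top_k]

    # Top-p: keep the smallest prefix whose renormalized mass reaches top_p.
    mass = sum(probs[i] for i in ranked)
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break

    # Sample among the survivors in proportion to their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [1.0, 3.0, 2.0, -1.0]
    print([top_k_top_p_sample(logits, top_k=3, top_p=0.9, rng=rng) for _ in range(10)])
```

During generation, the chosen id is appended to the sequence and the model is run again. This repeats until <|endoftext|> is produced or a length limit is reached. Applying top-k before top-p, with renormalization in between, is one common ordering; other implementations may differ in the details.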

Note: The first run will take a considerable amount of time because it:

  1. Downloads the pre-trained gpt2-medium model (depends on your network speed)
  2. Fine-tunes GPT-2 on your dataset (depends on the size of the data, the number of epochs, the hyperparameters, etc.)

All experiments were run on Intel DevCloud machines.

The documentation will be updated soon. Please open an issue to report any discrepancy.
