pl_prompt_sst

An example project using OpenPrompt under the framework of pytorch-lightning for a training prompt-based text classification model on SST2 sentiment analysis dataset. Leveraging the pytorch-lightning features like logging, gradient accumulation and early stopping, etc. Can be used as a template for further development.

Run

Install requirement

pip install -r requirements.txt

Setup the prompt to use in sst2/prompt_config.json

{
    "template_text": "{\"placeholder\": \"text_a\"} In summary, the film was {\"mask\"}.",
    "label_words": [["bad"], ["good"]]
}

Adjust the arguments in run.sh or the code below for your need, and run it.

CUDA_VISIBLE_DEVICES=0 python -u main.py --input_dir ./sst2 \
                                         --prompt_config_dir ./sst2/prompt_config.json \
                                         --model_class bert \
                                         --model_name_or_path prajjwal1/bert-tiny \
                                         --lr 2e-4
                                         --bs 32 \
                                         --max_seq_length 64 \
                                         --patience 4 \
                                         --accumulation 2 \
                                         --seed 666

In my preliminary experiment with the settings above, the model achieve 0.822 F1 compared to 0.820 without prompt.

Note

Can only be executed after this fix on state_dict()

GitHub

View Github