This repository provides the code for MedViLL (Medical Vision Language Learner).
Our proposed architecture, MedViLL, is a single BERT-based model that learns a unified contextualized vision-language (VL) representation for both Vision-Language Understanding (VLU) and Vision-Language Generation (VLG). MedViLL performs pre-training with a CNN-based visual encoder and a cross-modal Transformer for VL joint representation learning. After pre-training, our model can easily be used for VLU and VLG tasks with task-specific fine-tuning. Please refer to our paper “Multi-modal Understanding and Generation for Medical Images and Text via Vision-Language Pre-Training” for more details.
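For intuition, the sketch below shows how such a joint encoder can be wired up: CNN feature-map positions become visual tokens, text tokens are embedded, and both are fed to a single Transformer encoder. It is a minimal PyTorch sketch under our own assumptions (ResNet-50 backbone, module names, and the omission of positional/segment embeddings and pre-training heads), not the repository's actual implementation.

```python
# Minimal sketch of a MedViLL-style joint encoder. Backbone choice, names,
# and omitted embeddings/heads are assumptions, NOT the repository's code.
import torch
import torch.nn as nn
import torchvision.models as models

class JointEncoderSketch(nn.Module):
    def __init__(self, hidden_size=768, num_layers=12, num_heads=12, vocab_size=30522):
        super().__init__()
        resnet = models.resnet50(weights=None)              # CNN-based visual encoder (assumed backbone)
        self.visual_backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.visual_proj = nn.Linear(2048, hidden_size)     # project CNN features to BERT hidden size
        self.text_embed = nn.Embedding(vocab_size, hidden_size)
        layer = nn.TransformerEncoderLayer(hidden_size, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)  # cross-modal Transformer

    def forward(self, image, input_ids):
        v = self.visual_backbone(image)                     # (B, 2048, h, w) feature map
        v = self.visual_proj(v.flatten(2).transpose(1, 2))  # (B, h*w, hidden) visual tokens
        t = self.text_embed(input_ids)                      # (B, T, hidden) text tokens
        return self.encoder(torch.cat([v, t], dim=1))       # unified contextualized VL representation
```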
We provide five versions of BERT-based pre-trained weights, each using a different type of self-attention mask. Pre-training for the joint embedding was built on the BERT-base architecture (12 hidden layers, 12 attention heads, 768 hidden size); training details are described in our paper. The currently available versions of pre-trained weights are listed below (a sketch of the mask construction follows the list):
MedViLL – BERT-Base model with the Bidirectional Auto-regressive attention mask.
Bi & Seq2Seq – BERT-Base model with the Seq2Seq attention mask (75%) and the Bidirectional attention mask (25%) in every mini-batch.
Bidirectional – BERT-Base model with the Bidirectional attention mask.
Seq2Seq – BERT-Base model with the Seq2Seq attention mask.
Non-cross – BERT-Base model with the Non-cross modality attention mask.
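For intuition about these variants, here is a hedged sketch of how the bidirectional, Seq2Seq, and non-cross masks can be constructed over a joint sequence of V visual tokens followed by T text tokens (the proposed Bidirectional Auto-regressive mask is specified in the paper and omitted here); the function and layout are illustrative, not the repository's code:

```python
import torch

def make_mask(variant: str, V: int, T: int) -> torch.Tensor:
    """Boolean (V+T, V+T) self-attention mask; True = attention allowed.
    The joint sequence is V visual tokens followed by T text tokens."""
    N = V + T
    mask = torch.zeros(N, N, dtype=torch.bool)
    if variant == "bidirectional":
        mask[:] = True                           # every token attends to every token
    elif variant == "seq2seq":
        mask[:, :V] = True                       # image is fully visible to all tokens
        mask[V:, V:] = torch.tril(torch.ones(T, T)).bool()  # text attends left-to-right
    elif variant == "non_cross":
        mask[:V, :V] = True                      # visual tokens attend within the image only
        mask[V:, V:] = True                      # text tokens attend within the text only
    return mask
```

Under this scheme, the Bi & Seq2Seq weights above would draw the Seq2Seq mask for 75% of mini-batches and the bidirectional mask for the remaining 25%.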
We provide pre-processed versions of multiple datasets for each task as follows:
Download each dataset to the path /data/[dataset].
- MIMIC-CXR (2.27 GB): 91,685 unique studies, each an AP-view image paired with its associated report.
- OPEN-I (74.1 MB): 3,547 unique AP- and PA-view image-report pairs from the official Open-I dataset.
- VQA-RAD (402 MB): 3,515 question-answer pairs on 315 images (104 head CTs or MRIs, 107 chest X-rays, and 104 abdominal CTs).
We also provide the JSON files with the paths used for validation in the retrieval tasks; download each file to the path /data/[dataset]. A generic Recall@K sketch for retrieval evaluation appears after this list.
- Image-to-report retrieval
- Report-to-image retrieval
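For reference, retrieval in both directions is typically evaluated with Recall@K over a query-candidate similarity matrix. Below is a generic sketch, independent of the repository's evaluation code; it assumes the ground-truth pairs lie on the diagonal of the matrix.

```python
import torch

def recall_at_k(sim: torch.Tensor, k: int) -> float:
    """sim: (N, N) similarity matrix; sim[i, i] is the matching pair."""
    ranks = sim.argsort(dim=1, descending=True)        # candidate indices, best first
    targets = torch.arange(sim.size(0)).unsqueeze(1)   # ground-truth index per query
    hits = (ranks[:, :k] == targets).any(dim=1)        # true pair within top-k?
    return hits.float().mean().item()

# e.g., image-to-report retrieval: rows are image queries, columns are reports
sim = torch.randn(100, 100)
print(recall_at_k(sim, k=10))
```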
Section A. Installation
The sections below describe the virtual environment setup and the fine-tuning process of MedViLL, based on PyTorch 1.7 and Python 3.8.
To fine-tune MedViLL, you first need to download its pre-trained weights. After downloading them, use medvill.yaml to create the conda environment as follows, and activate the environment (its name is defined in medvill.yaml) before running any of the commands below:
$ git clone https://github.com/SuperSupermoon/MedViLL.git
$ cd MedViLL; conda env create --file medvill.yaml
Note that all fine-tuning experiments were run on 8 GeForce RTX 3090 GPUs, each with 24 GB of VRAM.
Section B. Prepare the pre-processed datasets
Unzip the mimic, openi, and VQA-RAD tar.gz files:
$ cd MedViLL; tar -zxvf [file_name.tar.gz]
Section C. Pre-training model
$ cd MedViLL
$ python main.py
Section D. Downstream tasks
- Diagnosis Classification
$ cd MedViLL/downstream_task/classification
$ python cls.py
- Image-Report Retrieval
$ cd MedViLL/downstream_task/retrieval
$ python retrieval.py
- Medical Visual Question Answering
$ cd MedViLL/downstream_task/report_generation_and_vqa
$ python finetune.py --tasks vqa --s2s_prob 0 --bi_prob 1 --mask_prob 0
- Report Generation
$ cd MedViLL/downstream_task/report_generation_and_vqa
$ python finetune.py --tasks report_generation --mask_prob 0.15 --s2s_prob 1 --bi_prob 0

For both VQA and report generation, --s2s_prob and --bi_prob set the per-mini-batch ratio of Seq2Seq and bidirectional attention masks (mirroring the pre-training mask mixing described above), and --mask_prob sets the token masking probability.