cvpr20_LESA
This repository contains the implementation of Shared Attention for Multi-label Zero-shot Learning.
In this work, we address zero-shot multi-label learning, recognizing all seen and unseen labels with a shared multi-attention method and a novel training mechanism.
Prerequisites
- Python 3.x
- TensorFlow 1.8.0
- sklearn
- matplotlib
- skimage
- scipy==1.4.1
Data Preparation
Please download the vgg_19 model (http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz) and extract it into ./model/vgg_19. Make sure the extracted model is named vgg_19.ckpt.
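The download-and-extract step above can also be scripted. The following is a minimal sketch using only the Python standard library; the helper names download_archive and extract_checkpoint are hypothetical and not part of this repository, and the sketch assumes the archive holds vgg_19.ckpt at its top level.

```python
import os
import tarfile
import urllib.request

VGG_URL = "http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz"
MODEL_DIR = "./model/vgg_19"


def download_archive(url, archive_path):
    """Download the archive unless it is already on disk (hypothetical helper)."""
    if not os.path.exists(archive_path):
        os.makedirs(os.path.dirname(archive_path) or ".", exist_ok=True)
        urllib.request.urlretrieve(url, archive_path)
    return archive_path


def extract_checkpoint(archive_path, model_dir, ckpt_name="vgg_19.ckpt"):
    """Extract the archive and confirm the checkpoint file sits in model_dir."""
    os.makedirs(model_dir, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(model_dir)
    ckpt_path = os.path.join(model_dir, ckpt_name)
    # Assumption: the checkpoint is stored at the top level of the archive.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError("expected %s after extraction" % ckpt_path)
    return ckpt_path


# Usage (performs a network download):
# archive = download_archive(VGG_URL, os.path.join(MODEL_DIR, "vgg_19.tar.gz"))
# print(extract_checkpoint(archive, MODEL_DIR))
```

If the check raises FileNotFoundError, the checkpoint was probably extracted under a different name or into a sub-folder and needs to be moved or renamed by hand.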
NUS-WIDE
- Please download the NUS-WIDE images and meta-data into the ./data/NUS-WIDE folder according to the instructions within the folders ./data/NUS-WIDE and ./data/NUS-WIDE/Flickr.
- To extract features into TensorFlow storage format, please run:
python ./extract_data/extract_full_NUS_WIDE_images_VGG_feature_2_TFRecord.py  # data_set == Train: create NUS_WIDE_Train_full_feature_ZLIB.tfrecords
python ./extract_data/extract_full_NUS_WIDE_images_VGG_feature_2_TFRecord.py  # data_set == Test: create NUS_WIDE_Test_full_feature_ZLIB.tfrecords
Please change the data_set variable in the script to Train and Test to extract NUS_WIDE_Train_full_feature_ZLIB.tfrecords and NUS_WIDE_Test_full_feature_ZLIB.tfrecords, respectively.
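Since the extraction script has to be run once per value of data_set, it can help to confirm that both output files exist before training. Below is a small sketch of such a check; the function name missing_outputs is hypothetical, and the output directory is a parameter because this README does not state where the script writes its files.

```python
import os

# File names taken from the extraction step above.
NUS_WIDE_TFRECORDS = [
    "NUS_WIDE_Train_full_feature_ZLIB.tfrecords",
    "NUS_WIDE_Test_full_feature_ZLIB.tfrecords",
]


def missing_outputs(output_dir, expected_names):
    """Return the expected files that are absent or empty in output_dir."""
    missing = []
    for name in expected_names:
        path = os.path.join(output_dir, name)
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing


# Usage (output_dir is an assumption; point it at the folder the script wrote to):
# print(missing_outputs("./", NUS_WIDE_TFRECORDS))
```

An empty list means both splits were extracted; any name in the list points to the data_set value that still needs to be run.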
Open Images
- Please download the Open Images URLs and annotations into the ./data/OpenImages folder according to the instructions within the folders ./data/OpenImages/2017_11 and ./data/OpenImages/2018_04.
- To crawl images from the web, please run the script:
python ./download_imgs/asyn_image_downloader.py  # data_set == train: download images into ./image_data/train/
python ./download_imgs/asyn_image_downloader.py  # data_set == validation: download images into ./image_data/validation/
python ./download_imgs/asyn_image_downloader.py  # data_set == test: download images into ./image_data/test/
Please change the data_set variable in the script to train, validation, and test to download different data splits.
- To extract features into TensorFlow storage format, please run:
python ./extract_data/extract_images_VGG_feature_2_TFRecord.py  # data_set == train: create train_feature_2018_04_ZLIB.tfrecords
python ./extract_data/extract_images_VGG_feature_2_TFRecord.py  # data_set == validation: create validation_feature_2018_04_ZLIB.tfrecords
python ./extract_data/extract_test_seen_unseen_images_VGG_feature_2_TFRecord.py  # data_set == test: create OI_seen_unseen_test_feature_2018_04_ZLIB.tfrecords
Please change the data_set variable in the extract_images_VGG_feature_2_TFRecord.py script to train and validation to extract features from different data splits.
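Because the images are crawled from the web, some downloads may fail, so it can be useful to count what actually landed in each split folder before extracting features. The sketch below assumes the downloader stores images as files directly inside ./image_data/<split>/; the function name count_downloaded is hypothetical.

```python
import os

SPLITS = ("train", "validation", "test")


def count_downloaded(image_root="./image_data", splits=SPLITS):
    """Map each split name to the number of regular files in its folder."""
    counts = {}
    for split in splits:
        split_dir = os.path.join(image_root, split)
        if os.path.isdir(split_dir):
            counts[split] = sum(
                1 for entry in os.scandir(split_dir) if entry.is_file()
            )
        else:
            counts[split] = 0  # folder not created yet
    return counts


# Usage:
# for split, n in count_downloaded().items():
#     print(split, n)
```

A count of zero for a split suggests the downloader has not yet been run with data_set set to that split.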
Training and Evaluation
NUS-WIDE
- To train and evaluate the zero-shot learning model on the full NUS-WIDE dataset, please run:
python ./zeroshot_experiments/NUS_WIDE_zs_rank_Visual_Word_Attention.py
Open Images
- To train our framework, please run:
python ./multilabel_experiments/OpenImage_rank_Visual_Word_Attention.py  # create a model checkpoint in ./results
- To evaluate zero-shot performance, please run:
python ./zeroshot_experiments/OpenImage_evaluate_top_multi_label.py  # requires evaluation_path to be set
Please set the evaluation_path variable in the script to the model checkpoint created by the training step above.
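To find a value for evaluation_path, one option is to look up the most recent run folder under ./results. The sketch below assumes that each training run writes into its own sub-folder of ./results, which is inferred from the released checkpoint path and not confirmed here; the function name newest_run_dir is hypothetical, and whether evaluation_path expects the folder itself or a checkpoint file inside it should be checked in the evaluation script.

```python
import os


def newest_run_dir(results_dir="./results"):
    """Return the most recently modified subdirectory of results_dir, or None."""
    run_dirs = [
        entry.path for entry in os.scandir(results_dir) if entry.is_dir()
    ]
    if not run_dirs:
        return None
    # Assumption: the latest training run is the folder modified last.
    return max(run_dirs, key=os.path.getmtime)


# Usage:
# print(newest_run_dir("./results"))
```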
Model Checkpoint
We also include the checkpoint of the zero-shot model on NUS-WIDE for fast evaluation (./results/release_zs_NUS_WIDE_log_GPU_7_1587185916d2570488/).
Citation
If this code is helpful for your research, we would appreciate it if you cite our work:
@article{Huynh-LESA:CVPR20,
author = {D.~Huynh and E.~Elhamifar},
title = {A Shared Multi-Attention Framework for Multi-Label Zero-Shot Learning},
journal = {{IEEE} Conference on Computer Vision and Pattern Recognition},
year = {2020}}