Token Labeling Toolbox for training image models

All Tokens Matter: Token Labeling for Training Better Vision Transformers (arXiv)

This is a PyTorch implementation of our paper.

[Figure] Comparison between the proposed LV-ViT and other recent transformer-based works. Note that we only show models with fewer than 100M parameters.

Our code is based on pytorch-image-models by Ross Wightman.

Update

2021.7: Add script to generate label data.

2021.6: Support pip install tlt to use our Token Labeling Toolbox for image models (see the usage sketch after this list).

2021.6: Release training code and segmentation model.

2021.4: Release LV-ViT models.
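
A minimal usage sketch for the pip package. This assumes the toolbox exposes its model constructors under tlt.models with names mirroring the --model flags used in the training commands below (e.g. lvvit_s); treat the exact import path as an assumption rather than documented API.

```python
# Hedged sketch: build an LV-ViT-S classifier from the pip package.
# The import path tlt.models and the constructor name lvvit_s are
# assumptions based on the --model names used elsewhere in this README.
import torch
from tlt.models import lvvit_s

model = lvvit_s()  # LV-ViT-S for 224x224 inputs
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) ImageNet logits
```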

LV-ViT Models

| Model | Layers | Embed dim | Image resolution | Params | Top-1 (%) | Download |
|----------|----|-----|-----|---------|------|------|
| LV-ViT-S | 16 | 384 | 224 | 26.15M  | 83.3 | link |
| LV-ViT-S | 16 | 384 | 384 | 26.30M  | 84.4 | link |
| LV-ViT-M | 20 | 512 | 224 | 55.83M  | 84.0 | link |
| LV-ViT-M | 20 | 512 | 384 | 56.03M  | 85.4 | link |
| LV-ViT-M | 20 | 512 | 448 | 56.13M  | 85.5 | link |
| LV-ViT-L | 24 | 768 | 448 | 150.47M | 86.2 | link |
| LV-ViT-L | 24 | 768 | 512 | 150.66M | 86.4 | link |

Requirements

torch>=1.4.0
torchvision>=0.5.0
pyyaml
scipy
timm==0.4.5

Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
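
As a quick sanity check (not part of the toolbox), this layout is the standard torchvision ImageFolder format, so it can be loaded directly:

```python
# Verify the ImageNet folder layout with torchvision's ImageFolder.
from torchvision import datasets

train_set = datasets.ImageFolder('/path/to/imagenet/train')
val_set = datasets.ImageFolder('/path/to/imagenet/val')
# Expect 1000 classes, ~1.28M training images and 50,000 val images.
print(len(train_set.classes), len(train_set), len(val_set))
```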

Validation

In the command below, replace /path/to/imagenet/val with your ImageNet validation set path and /path/to/checkpoint with the checkpoint path:

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint
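
For reference, here is a hedged sketch of what this evaluation amounts to: load the checkpoint into the matching model and compute top-1 accuracy over the validation set. The checkpoint key name ('model' vs. a raw state_dict), the tlt.models import path, and the ~0.9 center-crop ratio are assumptions and may not match eval.sh exactly.

```python
# Hedged sketch of an ImageNet validation loop for LV-ViT-S.
import torch
from torchvision import datasets, transforms
from tlt.models import lvvit_s  # import path is an assumption

tf = transforms.Compose([
    transforms.Resize(248),   # assumes a ~0.9 crop ratio for 224 eval
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('/path/to/imagenet/val', tf),
    batch_size=128, num_workers=8)

model = lvvit_s().cuda().eval()
state = torch.load('/path/to/checkpoint', map_location='cpu')
model.load_state_dict(state.get('model', state))  # key name is an assumption

correct = total = 0
with torch.no_grad():
    for x, y in loader:
        correct += (model(x.cuda()).argmax(1).cpu() == y).sum().item()
        total += y.numel()
print(f'top-1: {100 * correct / total:.2f}%')
```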

Label data

We provide the NFNet-F6 generated dense label maps in Google Drive and Baidu Yun (password: y6j2). As NFNet-F6 is trained on pure ImageNet data, no extra training data is involved.

Training

Train the LV-ViT-S:

If only 4 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_s -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

If 8 GPUs are available (the per-GPU batch size is halved, so the effective batch size stays at 4 × 256 = 8 × 128 = 1024):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

Train the LV-ViT-M and LV-ViT-L (run on 8 GPUs):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_m -b 128 --apex-amp --img-size 224 --drop-path 0.2 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_l -b 128 --lr 1.e-3 --aa rand-n3-m9-mstd0.5-inc1 --apex-amp --img-size 224 --drop-path 0.3 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

If you want to train our LV-ViT on images at 384x384 resolution, please use --img-size 384 --token-label-size 24 (the token label grid follows the 16x16-patch tokenization: 384/16 = 24, just as 224/16 = 14).

Fine-tuning

To fine-tune the pre-trained LV-ViT-S on images at 384x384 resolution:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint

To fine-tune the pre-trained LV-ViT-S on other datasets without token labeling (--dense-weight 0.0 disables the token labeling loss):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/dataset --model lvvit_s -b 64 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes $NUM_CLASSES --finetune /path/to/checkpoint

Segmentation

Our segmentation models are fully based on the MMSegmentation Toolkit. The model and config files are under the seg/ folder and follow the same folder structure. You can simply drop these files in to get started.

git clone https://github.com/open-mmlab/mmsegmentation # and install

cp seg/mmseg/models/backbones/vit.py mmsegmentation/mmseg/models/backbones/
cp -r seg/configs/lvvit mmsegmentation/configs/

# test upernet+lvvit_s (add --aug-test for multi-scale testing)
cd mmsegmentation
./tools/dist_test.sh configs/lvvit/upernet_lvvit_s_512x512_160k_ade20k.py /path/to/checkpoint 8 --eval mIoU [--aug-test]

| Backbone | Method  | Crop size | LR schedule | mIoU | mIoU (ms) | Pixel acc. | Params | Download |
|----------|---------|-----------|------|------|------|------|------|------|
| LV-ViT-S | UperNet | 512x512   | 160k | 47.9 | 48.6 | 83.1 | 44M  | link |
| LV-ViT-M | UperNet | 512x512   | 160k | 49.4 | 50.6 | 83.5 | 77M  | link |
| LV-ViT-L | UperNet | 512x512   | 160k | 50.9 | 51.8 | 84.1 | 209M | link |

Visualization

We apply the visualization method from this repo to visualize the parts of the image that led to a certain classification for DeiT-Base and our LV-ViT-S. The parts of the image used by the network to make the decision are highlighted in red.


Label generation

To generate token label data for training:

python3 generate_label.py /path/to/imagenet/train /path/to/save/label_top5_train_nfnet --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
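
Conceptually, this step runs the annotator network fully convolutionally and stores, for every spatial location, the top-5 class scores. The sketch below illustrates that idea only; it is not the repository's generate_label.py, and the timm attribute names (forward_features, head.fc) are common conventions that may differ for a given model.

```python
# Conceptual sketch of dense top-5 label-map generation with NFNet-F6.
import torch
import timm

model = timm.create_model('dm_nfnet_f6', pretrained=True).eval()

x = torch.randn(1, 3, 576, 576)  # stand-in for one training image
with torch.no_grad():
    feat = model.forward_features(x)   # [B, C, h, w] spatial feature map
    fc = model.head.fc                 # final Linear classifier (assumed name)
    # Apply the classifier at every location instead of after pooling.
    logits = torch.einsum('bchw,nc->bnhw', feat, fc.weight)
    logits = logits + fc.bias.view(1, -1, 1, 1)
    scores, labels = logits.softmax(dim=1).topk(5, dim=1)  # dense top-5 maps
```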

Reference

If you use this repo or find it useful, please consider citing:

@article{jiang2021all,
  title={All Tokens Matter: Token Labeling for Training Better Vision Transformers},
  author={Jiang, Zihang and Hou, Qibin and Yuan, Li and Zhou, Daquan and Shi, Yujun and Jin, Xiaojie and Wang, Anran and Feng, Jiashi},
  journal={arXiv preprint arXiv:2104.10858},
  year={2021}
}

Related projects

T2T-ViT, Re-labeling ImageNet, MMSegmentation, Transformer Explainability.
