This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021.
We study low-level computer vision tasks (e.g., denoising, super-resolution, and deraining) and develop a new pre-trained model, the image processing transformer (IPT). To fully exploit the capability of the transformer, we propose using the well-known ImageNet benchmark to generate a large number of corrupted image pairs. The IPT model is trained on these images with multi-heads and multi-tails. In addition, contrastive learning is introduced so that the model adapts well to different image processing tasks. The pre-trained model can therefore be efficiently employed on the desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks.
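
The idea of building corrupted image pairs from clean images can be illustrated for the denoising case. The following is a minimal sketch, not the authors' data pipeline: it assumes additive Gaussian noise on pixel values in [0, 255], represents an image as nested lists to stay dependency-free, and uses the hypothetical helper names `add_gaussian_noise` and `make_denoising_pair`. The noise levels 30 and 50 are chosen to match the denoising checkpoints listed in this repository.

```python
import random


def add_gaussian_noise(image, sigma, seed=None):
    """Return a noisy copy of `image` (rows of pixel values in [0, 255]).

    Assumption: zero-mean additive Gaussian noise with standard deviation
    `sigma`, clipped back to the valid pixel range.
    """
    rng = random.Random(seed)
    noisy = []
    for row in image:
        noisy_row = []
        for pixel in row:
            value = pixel + rng.gauss(0.0, sigma)
            noisy_row.append(min(255.0, max(0.0, value)))
        noisy.append(noisy_row)
    return noisy


def make_denoising_pair(clean, sigma, seed=None):
    """Build one (corrupted, clean) pair for the denoising task."""
    return add_gaussian_noise(clean, sigma, seed), clean


# A tiny 4x4 grayscale "image" stands in for an ImageNet sample.
clean = [[128.0] * 4 for _ in range(4)]

# One pair per noise level; the clean image is the shared target.
pairs = {sigma: make_denoising_pair(clean, sigma, seed=0) for sigma in (30, 50)}
```

Pairs for other tasks would be built the same way, with the corruption function swapped for a downsampling or rain-synthesis step.
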
If you find our work useful in your research or publication, please cite our work:
author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen},
booktitle={CVPR},
year={2021}
}
## Model architecture
### The overall network architecture of IPT is shown below:
> This is the inference script of IPT. You can follow the steps in this README to test image processing tasks such as SR, denoising, and deraining with the corresponding pre-trained models.
### Scripts and Sample Code
```
IPT
├── eval.py                     # inference entry
├── image
│   └── ipt.png                 # the illustration of IPT network
├── model
│   ├── IPT_denoise30.ckpt      # denoise model weights for noise level 30
│   ├── IPT_denoise50.ckpt      # denoise model weights for noise level 50
│   ├── IPT_derain.ckpt         # derain model weights
│   ├── IPT_sr2.ckpt            # X2 super-resolution model weights
│   ├── IPT_sr3.ckpt            # X3 super-resolution model weights
│   └── IPT_sr4.ckpt            # X4 super-resolution model weights
├── readme.md                   # Readme
├── scripts
│   └── run_eval.sh             # inference script for all tasks
└── src
    ├── args.py                 # options/hyper-parameters of IPT
    ├── data
    │   ├── common.py           # common dataset
    │   ├── __init__.py         # Class data init function
    │   └── srdata.py           # flow of loading sr data
    ├── foldunfold_stride.py    # function of fold and unfold operations for images
    ├── metrics.py              # PSNR calculator
    ├── template.py             # setting of model selection
    └── vitm.py                 # IPT network
```
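
The tree above lists `src/metrics.py` as a PSNR calculator. As a reminder of what that metric computes, here is a minimal sketch of the standard PSNR definition. It is not the code in `src/metrics.py`, which may differ (for example by cropping borders or converting to a luminance channel first); the function name `psnr` and the flat-sequence input format are assumptions made for illustration.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two flat pixel sequences.

    PSNR = 10 * log10(max_value^2 / MSE), where MSE is the mean squared
    error between the reference and the restored pixels.
    """
    if not reference or len(reference) != len(restored):
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((r - s) ** 2 for r, s in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical images: error-free reconstruction
    return 10.0 * math.log10(max_value ** 2 / mse)


# Every pixel is off by exactly 1, so MSE = 1 and PSNR = 20 * log10(255).
score = psnr([10.0, 20.0, 30.0], [11.0, 21.0, 29.0])
```

Higher values indicate a restored image closer to the reference.
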
### Script Parameter
> For details about hyperparameters, see src/args.py.
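
Each task setting corresponds to one of the weight files under `model/`. The sketch below shows one way such a mapping could be written. The helper `checkpoint_for` and its `task` and `level` arguments are hypothetical and are not options defined in `src/args.py`; only the file names are taken from the directory listing in this README.

```python
from pathlib import PurePosixPath

MODEL_DIR = PurePosixPath("model")


def checkpoint_for(task, level=None):
    """Return the relative path of the weights for a task (hypothetical helper).

    `level` is the upscaling factor for "sr", the noise level for "denoise",
    and is ignored for "derain".
    """
    if task == "sr":
        if level not in (2, 3, 4):
            raise ValueError("super-resolution weights exist for X2, X3 and X4")
        name = f"IPT_sr{level}.ckpt"
    elif task == "denoise":
        if level not in (30, 50):
            raise ValueError("denoising weights exist for noise levels 30 and 50")
        name = f"IPT_denoise{level}.ckpt"
    elif task == "derain":
        name = "IPT_derain.ckpt"
    else:
        raise ValueError(f"unknown task: {task!r}")
    return MODEL_DIR / name


sr_path = checkpoint_for("sr", 2)
denoise_path = checkpoint_for("denoise", 50)
derain_path = checkpoint_for("derain")
```

Rejecting unsupported levels early avoids a less informative file-not-found error later in the evaluation run.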