parent
fa4c19f938
commit
a1ff4ebd96
@ -0,0 +1,68 @@
|
||||
"""eval script"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from src import ipt
|
||||
from src.args import args
|
||||
from src.data.srdata import SRData
|
||||
from src.metrics import calc_psnr, quantize
|
||||
|
||||
from mindspore import context
|
||||
import mindspore.dataset as de
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
|
||||
|
||||
|
||||
def main():
|
||||
"""eval"""
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
||||
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
|
||||
train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
|
||||
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
|
||||
train_loader = train_de_dataset.create_dict_iterator()
|
||||
|
||||
net_m = ipt.IPT(args)
|
||||
print('load mindspore net successfully.')
|
||||
if args.pth_path:
|
||||
param_dict = load_checkpoint(args.pth_path)
|
||||
load_param_into_net(net_m, param_dict)
|
||||
net_m.set_train(False)
|
||||
num_imgs = train_de_dataset.get_dataset_size()
|
||||
psnrs = np.zeros((num_imgs, 1))
|
||||
for batch_idx, imgs in enumerate(train_loader):
|
||||
lr = imgs['LR']
|
||||
hr = imgs['HR']
|
||||
hr_np = np.float32(hr.asnumpy())
|
||||
pred = net_m.infrc(lr)
|
||||
pred_np = np.float32(pred.asnumpy())
|
||||
pred_np = quantize(pred_np, 255)
|
||||
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
|
||||
psnrs[batch_idx, 0] = psnr
|
||||
if args.denoise:
|
||||
print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
|
||||
elif args.derain:
|
||||
print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)))
|
||||
else:
|
||||
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Start main function!")
|
||||
main()
|
After Width: | Height: | Size: 1.7 MiB |
@ -0,0 +1,26 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.vitm import ViT
|
||||
|
||||
|
||||
def IPT(*args, **kwargs):
|
||||
return ViT(*args, **kwargs)
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'IPT':
|
||||
return IPT(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
DATA_SET=$3
|
||||
PATH_CHECKPOINT=$4
|
||||
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
|
||||
##denoise
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
||||
|
||||
##derain
|
||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
|
@ -0,0 +1,239 @@
|
||||
'''args'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
from src import template
|
||||
|
||||
parser = argparse.ArgumentParser(description='EDSR and MDSR')
|
||||
|
||||
parser.add_argument('--debug', action='store_true',
|
||||
help='Enables debug mode')
|
||||
parser.add_argument('--template', default='.',
|
||||
help='You can set various templates in option.py')
|
||||
|
||||
# Hardware specifications
|
||||
parser.add_argument('--n_threads', type=int, default=6,
|
||||
help='number of threads for data loading')
|
||||
parser.add_argument('--cpu', action='store_true',
|
||||
help='use cpu only')
|
||||
parser.add_argument('--n_GPUs', type=int, default=1,
|
||||
help='number of GPUs')
|
||||
parser.add_argument('--seed', type=int, default=1,
|
||||
help='random seed')
|
||||
|
||||
# Data specifications
|
||||
parser.add_argument('--dir_data', type=str, default='/cache/data/',
|
||||
help='dataset directory')
|
||||
parser.add_argument('--dir_demo', type=str, default='../test',
|
||||
help='demo image directory')
|
||||
parser.add_argument('--data_train', type=str, default='DIV2K',
|
||||
help='train dataset name')
|
||||
parser.add_argument('--data_test', type=str, default='DIV2K',
|
||||
help='test dataset name')
|
||||
parser.add_argument('--data_range', type=str, default='1-800/801-810',
|
||||
help='train/test data range')
|
||||
parser.add_argument('--ext', type=str, default='sep',
|
||||
help='dataset file extension')
|
||||
parser.add_argument('--scale', type=str, default='4',
|
||||
help='super resolution scale')
|
||||
parser.add_argument('--patch_size', type=int, default=48,
|
||||
help='output patch size')
|
||||
parser.add_argument('--rgb_range', type=int, default=255,
|
||||
help='maximum value of RGB')
|
||||
parser.add_argument('--n_colors', type=int, default=3,
|
||||
help='number of color channels to use')
|
||||
parser.add_argument('--chop', action='store_true',
|
||||
help='enable memory-efficient forward')
|
||||
parser.add_argument('--no_augment', action='store_true',
|
||||
help='do not use data augmentation')
|
||||
|
||||
# Model specifications
|
||||
parser.add_argument('--model', default='vtip',
|
||||
help='model name')
|
||||
|
||||
parser.add_argument('--act', type=str, default='relu',
|
||||
help='activation function')
|
||||
parser.add_argument('--pre_train', type=str, default='',
|
||||
help='pre-trained model directory')
|
||||
parser.add_argument('--extend', type=str, default='.',
|
||||
help='pre-trained model directory')
|
||||
parser.add_argument('--n_resblocks', type=int, default=16,
|
||||
help='number of residual blocks')
|
||||
parser.add_argument('--n_feats', type=int, default=64,
|
||||
help='number of feature maps')
|
||||
parser.add_argument('--res_scale', type=float, default=1,
|
||||
help='residual scaling')
|
||||
parser.add_argument('--shift_mean', default=True,
|
||||
help='subtract pixel mean from the input')
|
||||
parser.add_argument('--dilation', action='store_true',
|
||||
help='use dilated convolution')
|
||||
parser.add_argument('--precision', type=str, default='single',
|
||||
choices=('single', 'half'),
|
||||
help='FP precision for test (single | half)')
|
||||
|
||||
# Option for Residual dense network (RDN)
|
||||
parser.add_argument('--G0', type=int, default=64,
|
||||
help='default number of filters. (Use in RDN)')
|
||||
parser.add_argument('--RDNkSize', type=int, default=3,
|
||||
help='default kernel size. (Use in RDN)')
|
||||
parser.add_argument('--RDNconfig', type=str, default='B',
|
||||
help='parameters config of RDN. (Use in RDN)')
|
||||
|
||||
# Option for Residual channel attention network (RCAN)
|
||||
parser.add_argument('--n_resgroups', type=int, default=10,
|
||||
help='number of residual groups')
|
||||
parser.add_argument('--reduction', type=int, default=16,
|
||||
help='number of feature maps reduction')
|
||||
|
||||
# Training specifications
|
||||
parser.add_argument('--reset', action='store_true',
|
||||
help='reset the training')
|
||||
parser.add_argument('--test_every', type=int, default=1000,
|
||||
help='do test per every N batches')
|
||||
parser.add_argument('--epochs', type=int, default=300,
|
||||
help='number of epochs to train')
|
||||
parser.add_argument('--batch_size', type=int, default=16,
|
||||
help='input batch size for training')
|
||||
parser.add_argument('--test_batch_size', type=int, default=1,
|
||||
help='input batch size for training')
|
||||
parser.add_argument('--split_batch', type=int, default=1,
|
||||
help='split the batch into smaller chunks')
|
||||
parser.add_argument('--self_ensemble', action='store_true',
|
||||
help='use self-ensemble method for test')
|
||||
parser.add_argument('--test_only', action='store_true',
|
||||
help='set this option to test the model')
|
||||
parser.add_argument('--gan_k', type=int, default=1,
|
||||
help='k value for adversarial loss')
|
||||
|
||||
# Optimization specifications
|
||||
parser.add_argument('--lr', type=float, default=1e-4,
|
||||
help='learning rate')
|
||||
parser.add_argument('--decay', type=str, default='200',
|
||||
help='learning rate decay type')
|
||||
parser.add_argument('--gamma', type=float, default=0.5,
|
||||
help='learning rate decay factor for step decay')
|
||||
parser.add_argument('--optimizer', default='ADAM',
|
||||
choices=('SGD', 'ADAM', 'RMSprop'),
|
||||
help='optimizer to use (SGD | ADAM | RMSprop)')
|
||||
parser.add_argument('--momentum', type=float, default=0.9,
|
||||
help='SGD momentum')
|
||||
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
|
||||
help='ADAM beta')
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8,
|
||||
help='ADAM epsilon for numerical stability')
|
||||
parser.add_argument('--weight_decay', type=float, default=0,
|
||||
help='weight decay')
|
||||
parser.add_argument('--gclip', type=float, default=0,
|
||||
help='gradient clipping threshold (0 = no clipping)')
|
||||
|
||||
# Loss specifications
|
||||
parser.add_argument('--loss', type=str, default='1*L1',
|
||||
help='loss function configuration')
|
||||
parser.add_argument('--skip_threshold', type=float, default='1e8',
|
||||
help='skipping batch that has large error')
|
||||
|
||||
# Log specifications
|
||||
parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/',
|
||||
help='file name to save')
|
||||
parser.add_argument('--load', type=str, default='',
|
||||
help='file name to load')
|
||||
parser.add_argument('--resume', type=int, default=0,
|
||||
help='resume from specific checkpoint')
|
||||
parser.add_argument('--save_models', action='store_true',
|
||||
help='save all intermediate models')
|
||||
parser.add_argument('--print_every', type=int, default=100,
|
||||
help='how many batches to wait before logging training status')
|
||||
parser.add_argument('--save_results', action='store_true',
|
||||
help='save output results')
|
||||
parser.add_argument('--save_gt', action='store_true',
|
||||
help='save low-resolution and high-resolution images together')
|
||||
|
||||
parser.add_argument('--scalelr', type=int, default=0)
|
||||
# cloud
|
||||
parser.add_argument('--moxfile', type=int, default=1)
|
||||
parser.add_argument('--imagenet', type=int, default=0)
|
||||
parser.add_argument('--data_url', type=str, help='path to dataset')
|
||||
parser.add_argument('--train_url', type=str, help='train_dir')
|
||||
parser.add_argument('--pretrain', type=str, default='')
|
||||
parser.add_argument('--pth_path', type=str, default='')
|
||||
parser.add_argument('--load_query', type=int, default=0)
|
||||
# transformer
|
||||
parser.add_argument('--patch_dim', type=int, default=3)
|
||||
parser.add_argument('--num_heads', type=int, default=12)
|
||||
parser.add_argument('--num_layers', type=int, default=12)
|
||||
parser.add_argument('--dropout_rate', type=float, default=0)
|
||||
parser.add_argument('--no_norm', action='store_true')
|
||||
parser.add_argument('--post_norm', action='store_true')
|
||||
parser.add_argument('--no_mlp', action='store_true')
|
||||
parser.add_argument('--test', action='store_true')
|
||||
parser.add_argument('--chop_new', action='store_true')
|
||||
parser.add_argument('--pos_every', action='store_true')
|
||||
parser.add_argument('--no_pos', action='store_true')
|
||||
parser.add_argument('--num_queries', type=int, default=6)
|
||||
parser.add_argument('--reweight', action='store_true')
|
||||
|
||||
# denoise
|
||||
parser.add_argument('--denoise', action='store_true')
|
||||
parser.add_argument('--sigma', type=float, default=25)
|
||||
|
||||
# derain
|
||||
parser.add_argument('--derain', action='store_true')
|
||||
parser.add_argument('--finetune', action='store_true')
|
||||
parser.add_argument('--derain_test', type=int, default=10)
|
||||
# alltask
|
||||
parser.add_argument('--alltask', action='store_true')
|
||||
|
||||
# dehaze
|
||||
parser.add_argument('--dehaze', action='store_true')
|
||||
parser.add_argument('--dehaze_test', type=int, default=100)
|
||||
parser.add_argument('--indoor', action='store_true')
|
||||
parser.add_argument('--outdoor', action='store_true')
|
||||
parser.add_argument('--nochange', action='store_true')
|
||||
# deblur
|
||||
parser.add_argument('--deblur', action='store_true')
|
||||
parser.add_argument('--deblur_test', type=int, default=1000)
|
||||
|
||||
# distribute
|
||||
parser.add_argument('--init_method', type=str,
|
||||
default=None, help='master address')
|
||||
parser.add_argument('--rank', type=int, default=0,
|
||||
help='Index of current task')
|
||||
parser.add_argument('--world_size', type=int, default=1,
|
||||
help='Total number of tasks')
|
||||
parser.add_argument('--gpu', default=None, type=int,
|
||||
help='GPU id to use.')
|
||||
parser.add_argument('--dist-url', default='', type=str,
|
||||
help='url used to set up distributed training')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
help='distributed backend')
|
||||
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 4)')
|
||||
parser.add_argument('--distribute', action='store_true')
|
||||
|
||||
args, unparsed = parser.parse_known_args()
|
||||
template.set_template(args)
|
||||
|
||||
args.scale = [int(x) for x in args.scale.split("+")]
|
||||
args.data_train = args.data_train.split('+')
|
||||
args.data_test = args.data_test.split('+')
|
||||
|
||||
if args.epochs == 0:
|
||||
args.epochs = 1e8
|
||||
|
||||
for arg in vars(args):
|
||||
if vars(args)[arg] == 'True':
|
||||
vars(args)[arg] = True
|
||||
elif vars(args)[arg] == 'False':
|
||||
vars(args)[arg] = False
|
@ -0,0 +1,35 @@
|
||||
"""data"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd"
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
class Data:
|
||||
"""data"""
|
||||
|
||||
def __init__(self, args):
|
||||
self.loader_train = None
|
||||
self.loader_test = []
|
||||
for d in args.data_test:
|
||||
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
|
||||
m = import_module('data.benchmark')
|
||||
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
|
||||
else:
|
||||
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
||||
m = import_module('data.' + module_name.lower())
|
||||
testset = getattr(m, module_name)(args, train=False, name=d)
|
||||
|
||||
self.loader_test.append(
|
||||
testset
|
||||
)
|
@ -0,0 +1,93 @@
|
||||
"""common"""
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import skimage.color as sc
|
||||
|
||||
|
||||
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
|
||||
"""common"""
|
||||
ih, iw = args[0].shape[:2]
|
||||
|
||||
tp = patch_size
|
||||
ip = tp // scale
|
||||
|
||||
ix = random.randrange(0, iw - ip + 1)
|
||||
iy = random.randrange(0, ih - ip + 1)
|
||||
|
||||
if not input_large:
|
||||
tx, ty = scale * ix, scale * iy
|
||||
else:
|
||||
tx, ty = ix, iy
|
||||
|
||||
ret = [
|
||||
args[0][iy:iy + ip, ix:ix + ip, :],
|
||||
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def set_channel(*args, n_channels=3):
|
||||
"""common"""
|
||||
|
||||
def _set_channel(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
|
||||
c = img.shape[2]
|
||||
if n_channels == 1 and c == 3:
|
||||
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
|
||||
elif n_channels == 3 and c == 1:
|
||||
img = np.concatenate([img] * n_channels, 2)
|
||||
|
||||
return img[:, :, :n_channels]
|
||||
|
||||
return [_set_channel(a) for a in args]
|
||||
|
||||
|
||||
def np2Tensor(*args, rgb_range=255):
|
||||
"""common"""
|
||||
|
||||
def _np2Tensor(img):
|
||||
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
||||
tensor = np_transpose.astype(np.float32)
|
||||
tensor = tensor * (rgb_range / 255)
|
||||
# tensor = torch.from_numpy(np_transpose).float()
|
||||
# tensor.mul_(rgb_range / 255)
|
||||
|
||||
return tensor
|
||||
|
||||
return [_np2Tensor(a) for a in args]
|
||||
|
||||
|
||||
def augment(*args, hflip=True, rot=True):
|
||||
"""common"""
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
return [_augment(a) for a in args]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,241 @@
|
||||
'''stride'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class _stride_unfold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
||||
super(_stride_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
else:
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.unfold = _unfold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
"""stride"""
|
||||
N, C, H, W = x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
nh = int(H / self.stride)
|
||||
nw = int(W / self.stride)
|
||||
for i in range(nh):
|
||||
leftup_idx_x.append(i * self.stride)
|
||||
for i in range(nw):
|
||||
leftup_idx_y.append(i * self.stride)
|
||||
NumBlock_x = len(leftup_idx_x)
|
||||
NumBlock_y = len(leftup_idx_y)
|
||||
zeroslike = P.ZerosLike()
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
|
||||
NumBlock_y * self.kernel_size), mstype.float32)
|
||||
N, C, H, W = unf_x.shape
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
unf_i = i * self.kernel_size
|
||||
unf_j = j * self.kernel_size
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size,
|
||||
org_j:org_j + self.kernel_size]
|
||||
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
|
||||
(zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
|
||||
unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
|
||||
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
|
||||
y = self.unfold(unf_x)
|
||||
return y
|
||||
|
||||
|
||||
class _stride_fold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
stride=-1):
|
||||
|
||||
super(_stride_fold_, self).__init__()
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
self.kernel_size = kernel_size
|
||||
else:
|
||||
self.kernel_size = [kernel_size, kernel_size]
|
||||
|
||||
if stride == -1:
|
||||
self.stride = kernel_size[0]
|
||||
else:
|
||||
self.stride = stride
|
||||
|
||||
self.output_shape = output_shape
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.fold = _fold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
if self.output_shape[0] == -1:
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx = []
|
||||
for i in range(0, H, self.kernel_size[0]):
|
||||
leftup_idx.append(i)
|
||||
NumBlock = len(leftup_idx)
|
||||
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
|
||||
|
||||
for i in range(NumBlock):
|
||||
for j in range(NumBlock):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx[i]
|
||||
org_j = leftup_idx[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
|
||||
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
|
||||
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
|
||||
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
|
||||
y = fold_x
|
||||
else:
|
||||
NumBlock_x = int(
|
||||
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
|
||||
NumBlock_y = int(
|
||||
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
|
||||
large_shape = [NumBlock_x * self.kernel_size[0],
|
||||
NumBlock_y * self.kernel_size[1]]
|
||||
self.fold = _fold_(self.kernel_size, large_shape)
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
for i in range(NumBlock_x):
|
||||
leftup_idx_x.append(i * self.kernel_size[0])
|
||||
for i in range(NumBlock_y):
|
||||
leftup_idx_y.append(i * self.kernel_size[1])
|
||||
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
|
||||
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
|
||||
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
|
||||
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
|
||||
y = fold_x
|
||||
return y
|
||||
|
||||
|
||||
class _unfold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
||||
super(_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
N, C, H, W = x.shape
|
||||
numH = int(H / self.kernel_size)
|
||||
numW = int(W / self.kernel_size)
|
||||
if numH * self.kernel_size != H or numW * self.kernel_size != W:
|
||||
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
|
||||
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, C, int(
|
||||
numH * numW), self.kernel_size, self.kernel_size))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
|
||||
return output_img
|
||||
|
||||
|
||||
class _fold_(nn.Cell):
|
||||
'''stride'''
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
stride=-1):
|
||||
|
||||
super(_fold_, self).__init__()
|
||||
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
self.kernel_size = kernel_size
|
||||
else:
|
||||
self.kernel_size = [kernel_size, kernel_size]
|
||||
|
||||
if stride == -1:
|
||||
self.stride = kernel_size[0]
|
||||
self.output_shape = output_shape
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
'''stride'''
|
||||
N, C, L = x.shape
|
||||
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
|
||||
if self.output_shape[0] == -1:
|
||||
numH = int(np.sqrt(C))
|
||||
numW = int(np.sqrt(C))
|
||||
org_H = int(numH * self.kernel_size[0])
|
||||
org_W = org_H
|
||||
else:
|
||||
org_H = int(self.output_shape[0])
|
||||
org_W = int(self.output_shape[1])
|
||||
numH = int(org_H / self.kernel_size[0])
|
||||
numW = int(org_W / self.kernel_size[1])
|
||||
|
||||
output_img = self.reshape(
|
||||
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
|
||||
output_img = self.reshape(
|
||||
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
|
||||
|
||||
output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
|
||||
return output_img
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,56 @@
|
||||
'''metrics'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def quantize(img, rgb_range):
|
||||
'''metrics'''
|
||||
pixel_range = 255 / rgb_range
|
||||
img = np.multiply(img, pixel_range)
|
||||
img = np.clip(img, 0, 255)
|
||||
img = np.round(img) / pixel_range
|
||||
return img
|
||||
|
||||
|
||||
def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
|
||||
'''metrics'''
|
||||
hr = np.float32(hr)
|
||||
sr = np.float32(sr)
|
||||
diff = (sr - hr) / rgb_range
|
||||
gray_coeffs = np.array([65.738, 129.057, 25.064]
|
||||
).reshape((1, 3, 1, 1)) / 256
|
||||
diff = np.multiply(diff, gray_coeffs).sum(1)
|
||||
if hr.size == 1:
|
||||
return 0
|
||||
if scale != 1:
|
||||
shave = scale
|
||||
else:
|
||||
shave = scale + 6
|
||||
if scale == 1:
|
||||
valid = diff
|
||||
else:
|
||||
valid = diff[..., shave:-shave, shave:-shave]
|
||||
mse = np.mean(pow(valid, 2))
|
||||
return -10 * math.log10(mse)
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=True):
|
||||
'''metrics'''
|
||||
img.astype(np.float32)
|
||||
if y_only:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
return rlt
|
@ -0,0 +1,67 @@
|
||||
'''temp'''
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
def set_template(args):
|
||||
'''temp'''
|
||||
if args.template.find('jpeg') >= 0:
|
||||
args.data_train = 'DIV2K_jpeg'
|
||||
args.data_test = 'DIV2K_jpeg'
|
||||
args.epochs = 200
|
||||
args.decay = '100'
|
||||
|
||||
if args.template.find('EDSR_paper') >= 0:
|
||||
args.model = 'EDSR'
|
||||
args.n_resblocks = 32
|
||||
args.n_feats = 256
|
||||
args.res_scale = 0.1
|
||||
|
||||
if args.template.find('MDSR') >= 0:
|
||||
args.model = 'MDSR'
|
||||
args.patch_size = 48
|
||||
args.epochs = 650
|
||||
|
||||
if args.template.find('DDBPN') >= 0:
|
||||
args.model = 'DDBPN'
|
||||
args.patch_size = 128
|
||||
args.scale = '4'
|
||||
|
||||
args.data_test = 'Set5'
|
||||
|
||||
args.batch_size = 20
|
||||
args.epochs = 1000
|
||||
args.decay = '500'
|
||||
args.gamma = 0.1
|
||||
args.weight_decay = 1e-4
|
||||
|
||||
args.loss = '1*MSE'
|
||||
|
||||
if args.template.find('GAN') >= 0:
|
||||
args.epochs = 200
|
||||
args.lr = 5e-5
|
||||
args.decay = '150'
|
||||
|
||||
if args.template.find('RCAN') >= 0:
|
||||
args.model = 'RCAN'
|
||||
args.n_resgroups = 10
|
||||
args.n_resblocks = 20
|
||||
args.n_feats = 64
|
||||
args.chop = True
|
||||
|
||||
if args.template.find('VDSR') >= 0:
|
||||
args.model = 'VDSR'
|
||||
args.n_resblocks = 20
|
||||
args.n_feats = 64
|
||||
args.patch_size = 41
|
||||
args.lr = 1e-1
|
Loading…
Reference in new issue