!8711 Add CycleGAN in modelzoo

From: @zhao_ting_v
Reviewed-by: 
Signed-off-by:
pull/8711/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 31174b013e

@ -59,6 +59,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
- [CycleGAN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)

@ -0,0 +1,235 @@
# Contents
- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Knowledge Distillation Process](#knowledge-distillation-process)
- [Prediction Process](#prediction-process)
- [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset)
- [Export MindIR](#export-mindir)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CycleGAN Description](#contents)
Generative Adversarial Network (referred to as GAN) is an unsupervised learning method that learns by letting two neural networks play against each other. CycleGAN is a kind of GAN, which consists of two generation networks and two discriminant networks. It converts a certain type of pictures into another type of pictures through unpaired pictures, which can be used for style transfer.
[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y , Park T , Isola P , et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017.
# [Model Architecture](#contents)
The CycleGAN contains two generation networks and two discriminant networks. We support two architectures for generation networks: resnet and unet. Resnet architecture contains three convolutions, several residual blocks, two fractionally-strided convlutions with stride 1/2, and one convolution that maps features to RGB. Unet architecture contains three unet block to downsample and upsample, several unet blocks unet block and one convolution that maps features to RGB. For the discriminator networks we use 70 × 70 PatchGANs, which aim to classify whether 70 × 70 overlapping image patches are real or fake.
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [CityScape](<https://cityscapes-dataset.com>)
Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. We provide `src/utils/prepare_cityscapes_dataset.py` to process images. gtFine contains the semantics segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
The processed images will be placed at --output_dir.
Example usage:
```bash
python src/utils/prepare_cityscapes_dataset.py --gitFine_dir ./cityscapes/gtFine/ --leftImg8bit_dir ./cityscapes/leftImg8bit --output_dir ./cityscapes/
```
The directory structure is as follows:
```path
.
└─cityscapes
├─trainA
├─trainB
├─testA
└─testB
```
# [Environment Requirements](#contents)
- Hardware GPU
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```path
.
└─ cv
└─ cyclegan
├─ src
├─ __init__.py # init file
├─ dataset
├─ __init__.py # init file
├─ cyclegan_dataset.py # create cyclegan dataset
├─ datasets.py # UnalignedDataset and ImageFolderDataset class and some image utils
└─ distributed_sampler.py # iterator of dataset
├─ models
├─ __init__.py # init file
├─ cycle_gan.py # cyclegan model define
├─ losses.py # cyclegan losses function define
├─ networks.py # cyclegan sub networks define
├─ resnet.py # resnet generate network
└─ unet.py # unet generate network
└─ utils
├─ __init__.py # init file
├─ args.py # parse args
├─ prepare_cityscapes_dataset.py # prepare cityscapes dataset to cyclegan format
├─ cityscapes_utils.py # cityscapes dataset evaluation utils
├─ reporter.py # Reporter class
└─ tools.py # utils for cyclegan
├─ cityscape_eval.py # cityscape dataset eval script
├─ predict.py # generate images from A->B and B->A
├─ train.py # train script
├─ export.py # export mindir script
├─ README.md # descriptions about CycleGAN
└─ mindspore_hub_conf.py # mindspore hub interface
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and config.py as follows:
"model": "resnet" # generator model, should be in [resnet, unet].
"platform": "GPU" # run platform, support GPU, CPU and Ascend.
"device_id": 0 # device id, default is 0.
"lr": 0.0002 # init learning rate, default is 0.0002.
"pool_size": 50 # the size of image buffer that stores previously generated images, default is 50.
"lr_policy": "linear" # learning rate policy, default is linear.
"image_size": 256 # input image_size, default is 256.
"batch_size": 1 # batch_size, default is 1.
"max_epoch": 200 # epoch size for training, default is 200.
"n_epochs": 100 # number of epochs with the initial learning rate, default is 100
"beta1": 0.5 # Adam beta1, default is 0.5.
"init_type": normal # network initialization, default is normal.
"init_gain": 0.02 # scaling factor for normal, xavier and orthogonal, default is 0.02.
"in_planes": 3 # input channels, default is 3.
"ngf": 64 # generator model filter numbers, default is 64.
"gl_num": 9 # generator model residual block numbers, default is 9.
"ndf": 64 # discriminator model filter numbers, default is 64.
"dl_num": 3 # discriminator model residual block numbers, default is 3.
"slope": 0.2 # leakyrelu slope, default is 0.2.
"norm_mode":"instance" # norm mode, should be [batch, instance], default is instance.
"lambda_A": 10 # weight for cycle loss (A -> B -> A), default is 10.
"lambda_B": 10 # weight for cycle loss (B -> A -> B), default is 10.
"lambda_idt": 0.5 # if lambda_idt > 0 use identity mapping.
"gan_mode": lsgan # the type of GAN loss, should be [lsgan, vanilla], default is lsgan.
"pad_mode": REFLECT # the type of Pad, should be [CONSTANT, REFLECT, SYMMETRIC], default is REFLECT.
"need_dropout": True # whether need dropout, default is True.
"kd": False # knowledge distillation learning or not, default is False.
"t_ngf": 64 # teacher network generator model filter numbers when `kd` is True, default is 64.
"t_gl_num":9 # teacher network generator model residual block numbers when `kd` is True, default is 9.
"t_slope": 0.2 # teacher network leakyrelu slope when `kd` is True, default is 0.2.
"t_norm_mode": "instance" #teacher network norm mode when `kd` is True, defaultis instance.
"print_iter": 100 # log print iter, default is 100.
"outputs_dir": "outputs" # models are saved here, default is ./outputs.
"dataroot": None # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"save_imgs": True # whether save imgs when epoch end, if True result images will generate in `outputs_dir/imgs`, default is True.
"GT_A_ckpt": None # teacher network pretrained checkpoint file path of G_A when `kd` is True.
"GT_B_ckpt": None # teacher network pretrained checkpoint file path of G_B when `kd` is True.
"G_A_ckpt": None # pretrained checkpoint file path of G_A.
"G_B_ckpt": None # pretrained checkpoint file path of G_B.
"D_A_ckpt": None # pretrained checkpoint file path of D_A.
"D_B_ckpt": None # pretrained checkpoint file path of D_B.
```
## [Training Process](#contents)
```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH]
```
**Note: pad_mode should be CONSTANT when use Ascend and CPU. When using unet as generate network, the gl_num should less than 7.**
## [Knowledge Distillation Process](#contents)
```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH] --ngf [NGF] --kd True --GT_A_ckpt [G_A_CKPT] --GT_B_ckpt [G_B_CKPT]
```
**Note: the student network ngf should be 1/2 or 1/4 of teacher network ngf, if you change default args when training teacher generate networks, please change t_xx in knowledge distillation process.**
## [Prediction Process](#contents)
```bash
python predict.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
```
**Note: the result will saved at `outputs_dir/predict`.**
## [Evaluation with cityscape dataset](#contents)
```bash
python cityscape_eval.py --cityscapes_dir [LABEL_PATH] --result_dir [FAKEB_PATH]
```
**Note: Please run cityscape_eval.py after prediction process.**
## [Export MindIR](#contents)
```bash
python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
**Note: The file_name parameter is the prefix, the final file will as [FILE_NAME]_AtoB.[FILE_FORMAT] and [FILE_NAME]_BtoA.[FILE_FORMAT].**
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | CycleGAN |
| Resource | NV SMX2 V100-32G |
| uploaded Date | 12/10/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Cityscapes |
| Training Parameters | epoch=200, steps=2975, batch_size=1, lr=0.002 |
| Optimizer | Adam |
| Loss Function | Mean Sqare Loss & L1 Loss |
| outputs | probability |
| Speed | 1pc: 264 ms/step; |
| Total time | 1pc: 43.6h; |
| Parameters (M) | 11.378 M |
| Checkpoint for Fine tuning | 44M (.ckpt file) |
| Scripts | [CycleGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan) |
### Inference Performance
| Parameters | GPU |
| ------------------- | --------------------------- |
| Model Version | CycleGAN |
| Resource | GPU |
| Uploaded Date | 12/10/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Cityscapes |
| batch_size | 1 |
| outputs | probability |
| Accuracy | mean_pixel_acc: 54.8, mean_class_acc: 21.3, mean_class_iou: 16.1 |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval use cityscape dataset."""
import os
import argparse
import numpy as np
from src.dataset import make_dataset
from src.utils import CityScapes, fast_hist, get_scores
parser = argparse.ArgumentParser()
parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
args = parser.parse_args()
def main():
CS = CityScapes()
cityscapes = make_dataset(args.cityscapes_dir)
hist_perframe = np.zeros((CS.class_num, CS.class_num))
for i, img_path in enumerate(cityscapes):
if i % 100 == 0:
print('Evaluating: %d/%d' % (i, len(cityscapes)))
img_name = os.path.split(img_path)[1]
ids1 = CS.get_id(os.path.join(args.cityscapes_dir, img_name))
ids2 = CS.get_id(os.path.join(args.result_dir, img_name))
hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num)
mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
print(f"mean_pixel_acc: {mean_pixel_acc}, mean_class_acc: {mean_class_acc}, mean_class_iou: {mean_class_iou}")
with open('./evaluation_results.txt', 'w') as f:
f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
f.write('Mean class accuracy: %f\n' % mean_class_acc)
f.write('Mean class IoU: %f\n' % mean_class_iou)
f.write('************ Per class numbers below ************\n')
for i, cl in enumerate(CS.classes):
while len(cl) < 15:
cl = cl + ' '
f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))
if __name__ == '__main__':
main()

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export
from src.models import get_generator
from src.utils import get_args, load_ckpt
args = get_args("export")
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
if __name__ == '__main__':
G_A = get_generator(args)
G_B = get_generator(args)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and varance rather than moving_men and moving_varance in BatchNorm2d
G_A.set_train(True)
G_B.set_train(True)
load_ckpt(args, G_A, G_B)
input_shp = [1, 3, args.image_size, args.image_size]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
G_A_file = f"{args.file_name}_BtoA"
export(G_A, input_array, file_name=G_A_file, file_format=args.file_format)
G_B_file = f"{args.file_name}_AtoB"
export(G_B, input_array, file_name=G_B_file, file_format=args.file_format)

@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.models import get_generator
def create_network(name, *args, **kwargs):
if name == "cyclegan":
G_A = get_generator(*args, **kwargs)
G_B = get_generator(*args, **kwargs)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and varance rather than moving_men and moving_varance in BatchNorm2d
G_A.set_train(True)
G_B.set_train(True)
return G_A, G_B
raise NotImplementedError(f"{name} is not implemented in the repo")

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN predict."""
import os
from mindspore import Tensor
from src.models import get_generator
from src.utils import get_args, load_ckpt, save_image, Reporter
from src.dataset import create_dataset
def predict():
"""Predict function."""
args = get_args("predict")
G_A = get_generator(args)
G_B = get_generator(args)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and varance rather than moving_men and moving_varance in BatchNorm2d
G_A.set_train(True)
G_B.set_train(True)
load_ckpt(args, G_A, G_B)
imgs_out = os.path.join(args.outputs_dir, "predict")
if not os.path.exists(imgs_out):
os.makedirs(imgs_out)
if not os.path.exists(os.path.join(imgs_out, "fake_A")):
os.makedirs(os.path.join(imgs_out, "fake_A"))
if not os.path.exists(os.path.join(imgs_out, "fake_B")):
os.makedirs(os.path.join(imgs_out, "fake_B"))
args.data_dir = 'testA'
ds = create_dataset(args)
reporter = Reporter(args)
reporter.start_predict("A to B")
for data in ds.create_dict_iterator(output_numpy=True):
img_A = Tensor(data["image"])
path_A = str(data["image_name"][0], encoding="utf-8")
fake_B = G_A(img_A)
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_A))
reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A))
reporter.end_predict()
args.data_dir = 'testB'
ds = create_dataset(args)
reporter.dataset_size = args.dataset_size
reporter.start_predict("B to A")
for data in ds.create_dict_iterator(output_numpy=True):
img_B = Tensor(data["image"])
path_B = str(data["image_name"][0], encoding="utf-8")
fake_A = G_B(img_B)
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_B))
reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B))
reporter.end_predict()
if __name__ == "__main__":
predict()

@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .datasets import UnalignedDataset, ImageFolderDataset, make_dataset
from .cyclegan_dataset import create_dataset

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN dataset."""
import os
import multiprocessing
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from .distributed_sampler import DistributedSampler
from .datasets import UnalignedDataset, ImageFolderDataset
def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
"""Create dataset"""
dataroot = args.dataroot
phase = args.phase
batch_size = args.batch_size
device_num = args.device_num
rank = args.rank
cores = multiprocessing.cpu_count()
num_parallel_workers = min(8, int(cores / device_num))
image_size = args.image_size
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3
if phase == "train":
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
trans = [
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(1)
else:
datadir = os.path.join(dataroot, args.data_dir)
dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
num_parallel_workers=num_parallel_workers)
trans = [
C.Resize((image_size, image_size)),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
ds = ds.batch(1, drop_remainder=True)
ds = ds.repeat(1)
args.dataset_size = len(dataset)
return ds

@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN datasets."""
import os
import random
import numpy as np
from PIL import Image
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
def is_image_file(filename):
"""Judge whether it is a picture."""
return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir_path, max_dataset_size=float("inf")):
"""Return image list in dir."""
images = []
assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
for root, _, fnames in sorted(os.walk(dir_path)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images[:min(max_dataset_size, len(images))]
class UnalignedDataset:
"""
This dataset class can load unaligned/unpaired datasets.
Args:
dataroot (str): Images root directory.
phase (str): Train or test. It requires two directories in dataroot, like trainA and trainB to
host training images from domain A '{dataroot}/trainA' and from domain B '{dataroot}/trainB' respectively.
max_dataset_size (int): Maximum number of return image paths.
Returns:
Two domain image path list.
"""
def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
self.dir_A = os.path.join(dataroot, phase + 'A')
self.dir_B = os.path.join(dataroot, phase + 'B')
self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
def __getitem__(self, index):
if index % max(self.A_size, self.B_size) == 0:
random.shuffle(self.A_paths)
A_path = self.A_paths[index % self.A_size]
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = np.array(Image.open(A_path).convert('RGB'))
B_img = np.array(Image.open(B_path).convert('RGB'))
return A_img, B_img
def __len__(self):
return max(self.A_size, self.B_size)
class ImageFolderDataset:
"""
This dataset class can load images from image folder.
Args:
dataroot (str): Images root directory.
max_dataset_size (int): Maximum number of return image paths.
Returns:
Image path list.
"""
def __init__(self, dataroot, max_dataset_size=float("inf")):
self.dataroot = dataroot
self.paths = sorted(make_dataset(dataroot, max_dataset_size))
self.size = len(self.paths)
def __getitem__(self, index):
img_path = self.paths[index % self.size]
img = np.array(Image.open(img_path).convert('RGB'))
return img, os.path.split(img_path)[1]
def __len__(self):
return self.size

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
indices = indices.tolist()
self.epoch += 1
# change to list type
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples

@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
from .losses import DiscriminatorLoss, GeneratorLoss, GANLoss
from .networks import init_weights

File diff suppressed because it is too large Load Diff

@ -0,0 +1,175 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN losses"""
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from .cycle_gan import get_generator
from ..utils import load_teacher_ckpt
class BCEWithLogits(nn.Cell):
"""
BCEWithLogits creates a criterion to measure the Binary Cross Entropy between the true labels and
predicted labels with sigmoid logits.
Args:
reduction (str): Specifies the reduction to be applied to the output.
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
Outputs:
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
Otherwise, the output is a scalar.
"""
def __init__(self, reduction='mean'):
super(BCEWithLogits, self).__init__()
if reduction is None:
reduction = 'none'
if reduction not in ('mean', 'sum', 'none'):
raise ValueError(f"reduction method for {reduction.lower()} is not supported")
self.loss = ops.SigmoidCrossEntropyWithLogits()
self.reduce = False
if reduction == 'sum':
self.reduce_mode = ops.ReduceSum()
self.reduce = True
elif reduction == 'mean':
self.reduce_mode = ops.ReduceMean()
self.reduce = True
def construct(self, predict, target):
loss = self.loss(predict, target)
if self.reduce:
loss = self.reduce_mode(loss)
return loss
class GANLoss(nn.Cell):
"""
Cycle GAN loss factory.
Args:
mode (str): The type of GAN objective. It currently supports 'vanilla', 'lsgan'. Default: 'lsgan'.
reduction (str): Specifies the reduction to be applied to the output.
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
Outputs:
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
Otherwise, the output is a scalar.
"""
def __init__(self, mode="lsgan", reduction='mean'):
super(GANLoss, self).__init__()
self.loss = None
self.ones = ops.OnesLike()
if mode == "lsgan":
self.loss = nn.MSELoss(reduction)
elif mode == "vanilla":
self.loss = BCEWithLogits(reduction)
else:
raise NotImplementedError(f'GANLoss {mode} not recognized, we support lsgan and vanilla.')
def construct(self, predict, target):
target = ops.cast(target, ops.dtype(predict))
target = self.ones(predict) * target
loss = self.loss(predict, target)
return loss
class GeneratorLoss(nn.Cell):
"""
Cycle GAN generator loss.
Args:
args (class): Option class.
generator (Cell): Generator of CycleGAN.
D_A (Cell): The discriminator network of domain A to domain B.
D_B (Cell): The discriminator network of domain B to domain A.
Outputs:
Tuple Tensor, the losses of generator.
"""
def __init__(self, args, generator, D_A, D_B):
super(GeneratorLoss, self).__init__()
self.lambda_A = args.lambda_A
self.lambda_B = args.lambda_B
self.lambda_idt = args.lambda_idt
self.use_identity = args.lambda_idt > 0
self.dis_loss = GANLoss(args.gan_mode)
self.rec_loss = nn.L1Loss("mean")
self.generator = generator
self.D_A = D_A
self.D_B = D_B
self.true = Tensor(True, mstype.bool_)
self.kd = args.kd
if self.kd:
self.GT_A = get_generator(args, True)
load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A")
self.GT_B = get_generator(args, True)
load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B")
self.GT_A.set_train(True)
self.GT_B.set_train(True)
def construct(self, img_A, img_B):
"""If use_identity, identity loss will be used."""
fake_A, fake_B, rec_A, rec_B, identity_A, identity_B = self.generator(img_A, img_B)
loss_G_A = self.dis_loss(self.D_B(fake_B), self.true)
loss_G_B = self.dis_loss(self.D_A(fake_A), self.true)
loss_C_A = self.rec_loss(rec_A, img_A) * self.lambda_A
loss_C_B = self.rec_loss(rec_B, img_B) * self.lambda_B
if self.use_identity:
loss_idt_A = self.rec_loss(identity_A, img_A) * self.lambda_A * self.lambda_idt
loss_idt_B = self.rec_loss(identity_B, img_B) * self.lambda_B * self.lambda_idt
else:
loss_idt_A = 0
loss_idt_B = 0
loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B
if self.kd:
teacher_A = self.GT_B(img_B)
teacher_B = self.GT_A(img_A)
kd_loss_A = self.rec_loss(teacher_A, fake_A) * self.lambda_A * 5
kd_loss_B = self.rec_loss(teacher_B, fake_B) * self.lambda_A * 5
loss_G += kd_loss_A + kd_loss_B
return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B)
class DiscriminatorLoss(nn.Cell):
"""
Cycle GAN discriminator loss.
Args:
args (class): option class.
D_A (Cell): The discriminator network of domain A to domain B.
D_B (Cell): The discriminator network of domain B to domain A.
Outputs:
Tuple Tensor, the loss of discriminator.
"""
def __init__(self, args, D_A, D_B):
super(DiscriminatorLoss, self).__init__()
self.D_A = D_A
self.D_B = D_B
self.false = Tensor(False, mstype.bool_)
self.true = Tensor(True, mstype.bool_)
self.dis_loss = GANLoss(args.gan_mode)
self.rec_loss = nn.L1Loss("mean")
def construct(self, img_A, img_B, fake_A, fake_B):
D_fake_A = self.D_A(fake_A)
D_img_A = self.D_A(img_A)
D_fake_B = self.D_B(fake_B)
D_img_B = self.D_B(img_B)
loss_D_A = self.dis_loss(D_fake_A, self.false) + self.dis_loss(D_img_A, self.true)
loss_D_B = self.dis_loss(D_fake_B, self.false) + self.dis_loss(D_img_B, self.true)
loss_D = (loss_D_A + loss_D_B) * 0.5
return loss_D

@ -0,0 +1,156 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN network."""
import mindspore.nn as nn
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
init_type (str): The name of an initialization method: normal | xavier.
init_gain (float): Gain factor for normal and xavier.
"""
for cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain)))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain)))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class ConvNormReLU(nn.Cell):
"""
Convolution fused with BatchNorm/InstanceNorm and ReLU/LackyReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size. Default: 4.
stride (int): Stride size for the first convolutional layer. Default: 2.
alpha (float): Slope of LackyReLU. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
use_relu (bool): Use relu or not. Default: True.
padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None.
Returns:
Tensor, output tensor.
"""
def __init__(self,
in_planes,
out_planes,
kernel_size=4,
stride=2,
alpha=0.2,
norm_mode='batch',
pad_mode='CONSTANT',
use_relu=True,
padding=None):
super(ConvNormReLU, self).__init__()
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if padding is None:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
has_bias=has_bias, padding=padding)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class ConvTransposeNormReLU(nn.Cell):
"""
ConvTranspose2d fused with BatchNorm/InstanceNorm and ReLU/LackyReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size. Default: 4.
stride (int): Stride size for the first convolutional layer. Default: 2.
alpha (float): Slope of LackyReLU. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
use_relu (bool): use relu or not. Default: True.
padding (int): pad size, if it is None, it will calculate by kernel_size. Default: None.
Returns:
Tensor, output tensor.
"""
def __init__(self,
in_planes,
out_planes,
kernel_size=4,
stride=2,
alpha=0.2,
norm_mode='batch',
pad_mode='CONSTANT',
use_relu=True,
padding=None):
super(ConvTransposeNormReLU, self).__init__()
conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride=stride, pad_mode='same')
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if padding is None:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='same', has_bias=has_bias)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output

@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet Generator."""
import mindspore.nn as nn
import mindspore.ops as ops
from .networks import ConvNormReLU, ConvTransposeNormReLU
class ResidualBlock(nn.Cell):
"""
ResNet residual block definition.
Args:
dim (int): Input and output channel.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""
def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
super(ResidualBlock, self).__init__()
self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
self.dropout = dropout
if dropout:
self.dropout = nn.Dropout(0.5)
def construct(self, x):
out = self.conv1(x)
if self.dropout:
out = self.dropout(out)
out = self.conv2(out)
return x + out
class ResNetGenerator(nn.Cell):
"""
ResNet Generator of GAN.
Args:
in_planes (int): Input channel.
ngf (int): Output channel.
n_layers (int): The number of ConvNormReLU blocks.
alpha (float): LeakyRelu slope. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""
def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False,
pad_mode="CONSTANT"):
super(ResNetGenerator, self).__init__()
self.conv_in = ConvNormReLU(in_planes, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
self.down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode)
self.down_2 = ConvNormReLU(ngf * 2, ngf * 4, 3, 2, alpha, norm_mode)
layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
self.residuals = nn.SequentialCell(layers)
self.up_2 = ConvTransposeNormReLU(ngf * 4, ngf * 2, 3, 2, alpha, norm_mode)
self.up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode)
if pad_mode == "CONSTANT":
self.conv_out = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad', padding=3)
else:
pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad')
self.conv_out = nn.SequentialCell([pad, conv])
self.activate = ops.Tanh()
def construct(self, x):
x = self.conv_in(x)
x = self.down_1(x)
x = self.down_2(x)
x = self.residuals(x)
x = self.up_2(x)
x = self.up_1(x)
output = self.conv_out(x)
return self.activate(output)

@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UNet Generator."""
import mindspore.nn as nn
import mindspore.ops as ops
class UnetGenerator(nn.Cell):
"""
Unet-based generator.
Args:
in_planes (int): the number of channels in input images.
out_planes (int): the number of channels in output images.
ngf (int): the number of filters in the last conv layer.
n_layers (int): the number of downsamplings in UNet.
alpha (float): LeakyRelu slope. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
Returns:
Tensor, output tensor.
"""
def __init__(self, in_planes, out_planes, ngf=64, n_layers=7, alpha=0.2, norm_mode='bn', dropout=False):
super(UnetGenerator, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
norm_mode=norm_mode, innermost=True)
for _ in range(n_layers - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode, dropout=dropout)
# gradually reduce the number of filters from ngf * 8 to ngf
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
norm_mode=norm_mode)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
self.model = UnetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
outermost=True, norm_mode=norm_mode)
def construct(self, x):
return self.model(x)
class UnetSkipConnectionBlock(nn.Cell):
"""Unet submodule with skip connection.
Args:
outer_nc (int): The number of filters in the outer conv layer
inner_nc (int): The number of filters in the inner conv layer
in_planes (int): The number of channels in input images/features
dropout (bool): Use dropout or not. Default: False.
submodule (Cell): Previously defined submodules
outermost (bool): If this module is the outermost module
innermost (bool): If this module is the innermost module
alpha (float): LeakyRelu slope. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
Returns:
Tensor, output tensor.
"""
def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
super(UnetSkipConnectionBlock, self).__init__()
downnorm = nn.BatchNorm2d(inner_nc)
upnorm = nn.BatchNorm2d(outer_nc)
use_bias = False
if norm_mode == 'instance':
downnorm = nn.BatchNorm2d(inner_nc, affine=False)
upnorm = nn.BatchNorm2d(outer_nc, affine=False)
use_bias = True
if in_planes is None:
in_planes = outer_nc
downconv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
downrelu = nn.LeakyReLU(alpha)
uprelu = nn.ReLU()
if outermost:
upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, pad_mode='pad')
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.Conv2dTranspose(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, has_bias=use_bias, pad_mode='pad')
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
model = down + [submodule] + up
if dropout:
model.append(nn.Dropout(0.5))
self.model = nn.SequentialCell(model)
self.skip_connections = not outermost
self.concat = ops.Concat(axis=1)
def construct(self, x):
out = self.model(x)
if self.skip_connections:
out = self.concat((out, x))
return out

@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .args import get_args
from .reporter import Reporter
from .tools import get_lr, load_teacher_ckpt, ImagePool, load_ckpt, save_image
from .cityscapes_utils import CityScapes, fast_hist, get_scores

@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get args."""
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init, get_rank
def get_args(phase):
"""Define the common options that are used in both training and test."""
parser = argparse.ArgumentParser(description='Cycle GAN.')
# basic parameters
parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"), \
help='generator model, should be in [resnet, unet].')
parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default is 0.0002.")
parser.add_argument('--pool_size', type=int, default=50, \
help='the size of image buffer that stores previously generated images, default is 50.')
parser.add_argument('--lr_policy', type=str, default='linear', choices=("linear", "constant"), \
help='learning rate policy, default is linear')
parser.add_argument("--image_size", type=int, default=256, help="input image_size, default is 256.")
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
parser.add_argument('--n_epochs', type=int, default=100, \
help='number of epochs with the initial learning rate, default is 100')
parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1, default is 0.5.")
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), \
help='network initialization, default is normal.')
parser.add_argument('--init_gain', type=float, default=0.02, \
help='scaling factor for normal, xavier and orthogonal, default is 0.02.')
# model parameters
parser.add_argument('--in_planes', type=int, default=3, help='input channels, default is 3.')
parser.add_argument('--ngf', type=int, default=64, help='generator model filter numbers, default is 64.')
parser.add_argument('--gl_num', type=int, default=9, help='generator model residual block numbers, default is 9.')
parser.add_argument('--ndf', type=int, default=64, help='discriminator model filter numbers, default is 64.')
parser.add_argument('--dl_num', type=int, default=3, \
help='discriminator model residual block numbers, default is 3.')
parser.add_argument('--slope', type=float, default=0.2, help='leakyrelu slope, default is 0.2.')
parser.add_argument('--norm_mode', type=str, default="instance", choices=("batch", "instance"), \
help='norm mode, default is instance.')
parser.add_argument('--lambda_A', type=float, default=10.0, \
help='weight for cycle loss (A -> B -> A), default is 10.')
parser.add_argument('--lambda_B', type=float, default=10.0, \
help='weight for cycle loss (B -> A -> B), default is 10.')
parser.add_argument('--lambda_idt', type=float, default=0.5, \
help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the '
'weight of the identity mapping loss. For example, if the weight of the identity loss '
'should be 10 times smaller than the weight of the reconstruction loss,'
'please set lambda_identity = 0.1, default is 0.5.')
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=("lsgan", "vanilla"), \
help='the type of GAN loss, default is lsgan.')
parser.add_argument('--pad_mode', type=str, default='REFLECT', choices=("CONSTANT", "REFLECT", "SYMMETRIC"), \
help='the type of Pad, default is REFLECT.')
parser.add_argument('--need_dropout', type=ast.literal_eval, default=True, \
help='whether need dropout, default is True.')
# distillation learning parameters
parser.add_argument('--kd', type=ast.literal_eval, default=False, \
help='knowledge distillation learning or not, default is False.')
parser.add_argument('--t_ngf', type=int, default=64, \
help='teacher network generator model filter numbers when `kd` is True, default is 64.')
parser.add_argument('--t_gl_num', type=int, default=9, \
help='teacher network generator model residual block numbers when `kd` is True, default is 9.')
parser.add_argument('--t_slope', type=float, default=0.2, \
help='teacher network leakyrelu slope when `kd` is True, default is 0.2.')
parser.add_argument('--t_norm_mode', type=str, default="instance", choices=("batch", "instance"), \
help='teacher network norm mode when `kd` is True, default is instance.')
parser.add_argument("--GT_A_ckpt", type=str, default=None, \
help="teacher network pretrained checkpoint file path of G_A when `kd` is True.")
parser.add_argument("--GT_B_ckpt", type=str, default=None, \
help="teacher network pretrained checkpoint file path of G_B when `kd` is True.")
# additional parameters
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
parser.add_argument("--G_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_A.")
parser.add_argument("--G_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_B.")
parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.")
parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 10.")
parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.")
parser.add_argument('--need_profiler', type=ast.literal_eval, default=False, \
help='whether need profiler, default is False.')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
help='whether save graphs, default is False.')
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
help='models are saved here, default is ./outputs.')
parser.add_argument('--dataroot', default=None, \
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
help='whether save imgs when epoch end, if True result images will generate in '
'`outputs_dir/imgs`, default is True.')
if phase == "export":
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
help='file format')
args = parser.parse_args()
if args.device_num > 1 and args.platform != "CPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=args.device_num)
init()
args.rank = get_rank()
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform,
save_graphs=args.save_graphs, device_id=args.device_id)
args.rank = 0
args.device_num = 1
if args.platform != "GPU":
args.pad_mode = "CONSTANT"
if phase != "train" and (args.G_A_ckpt is None or args.G_B_ckpt is None):
raise ValueError('Must set G_A_ckpt and G_B_ckpt in predict phase!')
if args.kd:
if args.GT_A_ckpt is None or args.GT_B_ckpt is None:
raise ValueError('Must set GT_A_ckpt, GT_B_ckpt in knowledge distillation!')
if args.norm_mode == "instance" or (args.kd and args.t_norm_mode == "instance"):
args.batch_size = 1
if args.dataroot is None and (phase in ["train", "predict"]):
raise ValueError('Must set dataroot!')
args.n_epochs_decay = args.max_epoch - args.n_epochs
args.phase = phase
return args

@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cityscape utils."""
import numpy as np
from PIL import Image
# label name and RGB color map.
label2color = {
'unlabeled': (0, 0, 0),
'ego vehicle': (0, 0, 0),
'rectification border': (0, 0, 0),
'out of roi': (0, 0, 0),
'static': (0, 0, 0),
'dynamic': (111, 74, 0),
'ground': (81, 0, 81),
'road': (128, 64, 128),
'sidewalk': (244, 35, 232),
'parking': (250, 170, 160),
'rail track': (230, 150, 140),
'building': (70, 70, 70),
'wall': (102, 102, 156),
'fence': (190, 153, 153),
'guard rail': (180, 165, 180),
'bridge': (150, 100, 100),
'tunnel': (150, 120, 90),
'pole': (153, 153, 153),
'polegroup': (153, 153, 153),
'traffic light': (250, 170, 30),
'traffic sign': (220, 220, 0),
'vegetation': (107, 142, 35),
'terrain': (152, 251, 152),
'sky': (70, 130, 180),
'person': (220, 20, 60),
'rider': (255, 0, 0),
'car': (0, 0, 142),
'truck': (0, 0, 70),
'bus': (0, 60, 100),
'caravan': (0, 0, 90),
'trailer': (0, 0, 110),
'train': (0, 80, 100),
'motorcycle': (0, 0, 230),
'bicycle': (119, 11, 32),
'license plate': (0, 0, 142)
}
def fast_hist(a, b, n):
k = np.where((a >= 0) & (a < n))[0]
bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
if len(bc) != n**2:
# ignore this example if dimension mismatch
return 0
return bc.reshape(n, n)
def get_scores(hist):
# Mean pixel accuracy
acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
# Per class accuracy
cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
# Per class IoU
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu
class CityScapes:
"""CityScapes util class."""
def __init__(self):
self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck',
'bus', 'train', 'motorcycle', 'bicycle', 'unlabeled']
self.color_list = []
for name in self.classes:
self.color_list.append(label2color[name].color)
self.class_num = len(self.classes)
def get_id(self, img_path):
"""Get train id by img"""
img = np.array(Image.open(img_path).convert("RGB"))
w, h, _ = img.shape
img_tile = np.tile(img, (1, 1, self.class_num)).reshape(w, h, self.class_num, 3)
diff = np.abs(img_tile - self.color_list).sum(axis=-1)
ids = diff.argmin(axis=-1)
return ids

@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""prepare cityscapes dataset to cyclegan format"""
import os
import argparse
import glob
from PIL import Image
parser = argparse.ArgumentParser()
parser.add_argument('--gtFine_dir', type=str, required=True,
help='Path to the Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, required=True,
default='./cityscapes',
help='Directory the output images will be written to.')
opt = parser.parse_args()
def load_resized_img(path):
"""Load image with RGB and resize to (256, 256)"""
return Image.open(path).convert('RGB').resize((256, 256))
def check_matching_pair(segmap_path, photo_path):
"""Check the segment images and photo images are matched or not."""
segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
assert segmap_identifier == photo_identifier, \
f"[{segmap_path}] and [{photo_path}] don't seem to be matching. Aborting."
def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
"""Process citycapes dataset to cyclegan dataset format."""
save_phase = 'test' if phase == 'val' else 'train'
savedir = os.path.join(output_dir, save_phase)
os.makedirs(savedir + 'A', exist_ok=True)
os.makedirs(savedir + 'B', exist_ok=True)
print(f"Directory structure prepared at {output_dir}")
segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
segmap_paths = glob.glob(segmap_expr)
segmap_paths = sorted(segmap_paths)
photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
photo_paths = glob.glob(photo_expr)
photo_paths = sorted(photo_paths)
assert len(segmap_paths) == len(photo_paths), \
"{} images that match [{}], and {} images that match [{}]. Aborting.".format(
len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)
for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
check_matching_pair(segmap_path, photo_path)
segmap = load_resized_img(segmap_path)
photo = load_resized_img(photo_path)
# data for cyclegan where the two images are stored at two distinct directories
savepath = os.path.join(savedir + 'A', f"{i + 1}.jpg")
photo.save(savepath)
savepath = os.path.join(savedir + 'B', f"{i + 1}.jpg")
segmap.save(savepath)
if i % (len(segmap_paths) // 10) == 0:
print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath))
if __name__ == '__main__':
print('Preparing Cityscapes Dataset for val phase')
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
print('Preparing Cityscapes Dataset for train phase')
process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
print('Done')

@ -0,0 +1,144 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reporter class."""
import logging
import os
import time
from datetime import datetime
from mindspore.train.serialization import save_checkpoint
from .tools import save_image
class Reporter(logging.Logger):
"""
This class includes several functions that can save images/checkpoints and print/save logging information.
Args:
args (class): Option class.
"""
def __init__(self, args):
super(Reporter, self).__init__("cyclegan")
self.log_dir = os.path.join(args.outputs_dir, 'log')
self.imgs_dir = os.path.join(args.outputs_dir, "imgs")
self.ckpts_dir = os.path.join(args.outputs_dir, "ckpt")
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir, exist_ok=True)
if not os.path.exists(self.imgs_dir):
os.makedirs(self.imgs_dir, exist_ok=True)
if not os.path.exists(self.ckpts_dir):
os.makedirs(self.ckpts_dir, exist_ok=True)
self.rank = args.rank
self.save_checkpoint_epochs = args.save_checkpoint_epochs
self.save_imgs = args.save_imgs
# console handler
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
# file handler
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(self.rank)
self.log_fn = os.path.join(self.log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
self.addHandler(fh)
self.save_args(args)
self.step = 0
self.epoch = 0
self.dataset_size = args.dataset_size
self.print_iter = args.print_iter
self.G_loss = []
self.D_loss = []
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.logger.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def epoch_start(self):
self.step_start_time = time.time()
self.epoch_start_time = time.time()
self.step = 0
self.epoch += 1
self.G_loss = []
self.D_loss = []
def step_end(self, res_G, res_D):
"""print log when step end."""
self.step += 1
loss_D = float(res_D.asnumpy())
res = []
for item in res_G[2:]:
res.append(float(item.asnumpy()))
self.G_loss.append(res[0])
self.D_loss.append(loss_D)
if self.step % self.print_iter == 0:
step_cost = (time.time() - self.step_start_time) * 1000 / self.print_iter
losses = "G_loss: {:.2f}, D_loss:{:.2f}, loss_G_A: {:.2f}, loss_G_B: {:.2f}, loss_C_A: {:.2f},"\
"loss_C_B: {:.2f}, loss_idt_A: {:.2f}, loss_idt_B{:.2f}".format(
res[0], loss_D, res[1], res[2], res[3], res[4], res[5], res[6])
self.info("Epoch[{}] [{}/{}] step cost: {:.2f} ms, {}".format(
self.epoch, self.step, self.dataset_size, step_cost, losses))
self.step_start_time = time.time()
def epoch_end(self, net):
"""print log and save cgeckpoints when epoch end."""
epoch_cost = (time.time() - self.epoch_start_time) * 1000
pre_step_time = epoch_cost / self.dataset_size
mean_loss_G = sum(self.G_loss) / self.dataset_size
mean_loss_D = sum(self.D_loss) / self.dataset_size
self.info("Epoch [{}] total cost: {:.2f} ms, pre step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format(
self.epoch, epoch_cost, pre_step_time, mean_loss_G, mean_loss_D))
if self.epoch % self.save_checkpoint_epochs == 0 and self.rank == 0:
save_checkpoint(net.G.generator.G_A, os.path.join(self.ckpts_dir, f"G_A_{self.epoch}.ckpt"))
save_checkpoint(net.G.generator.G_B, os.path.join(self.ckpts_dir, f"G_B_{self.epoch}.ckpt"))
save_checkpoint(net.G.D_A, os.path.join(self.ckpts_dir, f"D_A_{self.epoch}.ckpt"))
save_checkpoint(net.G.D_B, os.path.join(self.ckpts_dir, f"D_B_{self.epoch}.ckpt"))
def visualizer(self, img_A, img_B, fake_A, fake_B):
if self.save_imgs and self.step % self.dataset_size == 0 and self.rank == 0:
save_image(img_A, os.path.join(self.imgs_dir, f"{self.epoch}_img_A.jpg"))
save_image(img_B, os.path.join(self.imgs_dir, f"{self.epoch}_img_B.jpg"))
save_image(fake_A, os.path.join(self.imgs_dir, f"{self.epoch}_fake_A.jpg"))
save_image(fake_B, os.path.join(self.imgs_dir, f"{self.epoch}_fake_B.jpg"))
def start_predict(self, direction):
self.predict_start_time = time.time()
self.direction = direction
self.info('==========start predict %s===============', self.direction)
def end_predict(self):
cost = (time.time() - self.predict_start_time) * 1000
pre_step_cost = cost / self.dataset_size
self.info('total {} imgs cost {:.2f} ms, pre img cost {:.2f}'.format(self.dataset_size, cost, pre_step_cost))
self.info('==========end predict %s===============\n', self.direction)

@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for cyclegan."""
import random
import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
class ImagePool():
"""
This class implements an image buffer that stores previously generated images.
This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators.
"""
def __init__(self, pool_size):
"""
Initialize the ImagePool class
Args:
pool_size (int): the size of image buffer, if pool_size=0, no buffer will be created.
"""
self.pool_size = pool_size
if self.pool_size > 0: # create an empty pool
self.num_imgs = 0
self.images = []
def query(self, images):
"""
Return an image from the pool.
Args:
images: the latest generated images from the generator
Returns images Tensor from the buffer.
By 50/100, the buffer will return input images.
By 50/100, the buffer will return images previously stored in the buffer,
and insert the current images to the buffer.
"""
if isinstance(images, Tensor):
images = images.asnumpy()
if self.pool_size == 0: # if the buffer size is 0, do nothing
return Tensor(images)
return_images = []
for image in images:
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
tmp = self.images[random_id].copy()
self.images[random_id] = image
return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image
return_images.append(image)
return_images = np.array(return_images) # collect all the images and return
if len(return_images.shape) != 4:
raise ValueError("img should be 4d, but get shape {}".format(return_images.shape))
return Tensor(return_images)
def save_image(img, img_path):
"""Save a numpy image to the disk
Parameters:
img (numpy array / Tensor): image to save.
image_path (str): the path of the image.
"""
if isinstance(img, Tensor):
img = decode_image(img)
elif not isinstance(img, np.ndarray):
raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
img_pil = Image.fromarray(img)
img_pil.save(img_path)
def decode_image(img):
"""Decode a [1, C, H, W] Tensor to image numpy array."""
mean = 0.5 * 255
std = 0.5 * 255
return (img.asnumpy()[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
def get_lr(args):
"""Learning rate generator."""
if args.lr_policy == 'linear':
lrs = [args.lr] * args.dataset_size * args.n_epochs
lr_epoch = 0
for epoch in range(args.n_epochs_decay):
lr_epoch = args.lr * (args.n_epochs_decay - epoch) / args.n_epochs_decay
lrs += [lr_epoch] * args.dataset_size
lrs += [lr_epoch] * args.dataset_size * (args.max_epoch - args.n_epochs_decay - args.n_epochs)
return Tensor(np.array(lrs).astype(np.float32))
return args.lr
def load_ckpt(args, G_A, G_B, D_A=None, D_B=None):
"""Load parameter from checkpoint."""
if args.G_A_ckpt is not None:
param_GA = load_checkpoint(args.G_A_ckpt)
load_param_into_net(G_A, param_GA)
if args.G_B_ckpt is not None:
param_GB = load_checkpoint(args.G_B_ckpt)
load_param_into_net(G_B, param_GB)
if D_A is not None and args.D_A_ckpt is not None:
param_DA = load_checkpoint(args.D_A_ckpt)
load_param_into_net(D_A, param_DA)
if D_B is not None and args.D_B_ckpt is not None:
param_DB = load_checkpoint(args.D_B_ckpt)
load_param_into_net(D_B, param_DB)
def load_teacher_ckpt(net, ckpt_path, teacher, student):
"""Replace parameter name to teacher net and load parameter from checkpoint."""
param = load_checkpoint(ckpt_path)
new_param = {}
for k, v in param.items():
new_name = k.replace(student, teacher)
new_param_name = v.name.replace(student, teacher)
v.name = new_param_name
new_param[new_name] = v
load_param_into_net(net, new_param)

@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN train."""
import mindspore.nn as nn
from mindspore.common import set_seed
from src.models import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD, \
DiscriminatorLoss, GeneratorLoss
from src.utils import get_lr, get_args, Reporter, ImagePool, load_ckpt
from src.dataset import create_dataset
set_seed(1)
def train():
"""Train function."""
args = get_args("train")
if args.need_profiler:
from mindspore.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
ds = create_dataset(args)
G_A = get_generator(args)
G_B = get_generator(args)
D_A = get_discriminator(args)
D_B = get_discriminator(args)
load_ckpt(args, G_A, G_B, D_A, D_B)
imgae_pool_A = ImagePool(args.pool_size)
imgae_pool_B = ImagePool(args.pool_size)
generator = Generator(G_A, G_B, args.lambda_idt > 0)
loss_D = DiscriminatorLoss(args, D_A, D_B)
loss_G = GeneratorLoss(args, generator, D_A, D_B)
optimizer_G = nn.Adam(generator.trainable_params(), get_lr(args), beta1=args.beta1)
optimizer_D = nn.Adam(loss_D.trainable_params(), get_lr(args), beta1=args.beta1)
net_G = TrainOneStepG(loss_G, generator, optimizer_G)
net_D = TrainOneStepD(loss_D, optimizer_D)
data_loader = ds.create_dict_iterator()
reporter = Reporter(args)
reporter.info('==========start training===============')
for _ in range(args.max_epoch):
reporter.epoch_start()
for data in data_loader:
img_A = data["image_A"]
img_B = data["image_B"]
res_G = net_G(img_A, img_B)
fake_A = res_G[0]
fake_B = res_G[1]
res_D = net_D(img_A, img_B, imgae_pool_A.query(fake_A), imgae_pool_B.query(fake_B))
reporter.step_end(res_G, res_D)
reporter.visualizer(img_A, img_B, fake_A, fake_B)
reporter.epoch_end(net_G)
if args.need_profiler:
profiler.analyse()
break
reporter.info('==========end training===============')
if __name__ == "__main__":
train()
Loading…
Cancel
Save