commit
31174b013e
@ -0,0 +1,54 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Eval use cityscape dataset."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from src.dataset import make_dataset
|
||||
from src.utils import CityScapes, fast_hist, get_scores
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
|
||||
parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
CS = CityScapes()
|
||||
cityscapes = make_dataset(args.cityscapes_dir)
|
||||
hist_perframe = np.zeros((CS.class_num, CS.class_num))
|
||||
for i, img_path in enumerate(cityscapes):
|
||||
if i % 100 == 0:
|
||||
print('Evaluating: %d/%d' % (i, len(cityscapes)))
|
||||
img_name = os.path.split(img_path)[1]
|
||||
ids1 = CS.get_id(os.path.join(args.cityscapes_dir, img_name))
|
||||
ids2 = CS.get_id(os.path.join(args.result_dir, img_name))
|
||||
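# Accumulate the per-image confusion matrices; the scores are computed once over the whole set.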
hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num)
|
||||
|
||||
mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
|
||||
print(f"mean_pixel_acc: {mean_pixel_acc}, mean_class_acc: {mean_class_acc}, mean_class_iou: {mean_class_iou}")
|
||||
with open('./evaluation_results.txt', 'w') as f:
|
||||
f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
|
||||
f.write('Mean class accuracy: %f\n' % mean_class_acc)
|
||||
f.write('Mean class IoU: %f\n' % mean_class_iou)
|
||||
f.write('************ Per class numbers below ************\n')
|
||||
for i, cl in enumerate(CS.classes):
|
||||
cl = cl.ljust(15)
|
||||
f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export file."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export
|
||||
from src.models import get_generator
|
||||
from src.utils import get_args, load_ckpt
|
||||
|
||||
args = get_args("export")
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||
|
||||
if __name__ == '__main__':
|
||||
G_A = get_generator(args)
|
||||
G_B = get_generator(args)
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use the batch's real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
|
||||
G_A.set_train(True)
|
||||
G_B.set_train(True)
|
||||
load_ckpt(args, G_A, G_B)
|
||||
|
||||
input_shp = [1, 3, args.image_size, args.image_size]
|
||||
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
|
||||
G_A_file = f"{args.file_name}_BtoA"
|
||||
export(G_A, input_array, file_name=G_A_file, file_format=args.file_format)
|
||||
G_B_file = f"{args.file_name}_AtoB"
|
||||
export(G_B, input_array, file_name=G_B_file, file_format=args.file_format)
|
@ -0,0 +1,27 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.models import get_generator
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == "cyclegan":
|
||||
G_A = get_generator(*args, **kwargs)
|
||||
G_B = get_generator(*args, **kwargs)
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use the batch's real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
|
||||
G_A.set_train(True)
|
||||
G_B.set_train(True)
|
||||
return G_A, G_B
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN predict."""
|
||||
import os
|
||||
from mindspore import Tensor
|
||||
|
||||
from src.models import get_generator
|
||||
from src.utils import get_args, load_ckpt, save_image, Reporter
|
||||
from src.dataset import create_dataset
|
||||
|
||||
def predict():
|
||||
"""Predict function."""
|
||||
args = get_args("predict")
|
||||
G_A = get_generator(args)
|
||||
G_B = get_generator(args)
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use the batch's real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
|
||||
G_A.set_train(True)
|
||||
G_B.set_train(True)
|
||||
load_ckpt(args, G_A, G_B)
|
||||
|
||||
imgs_out = os.path.join(args.outputs_dir, "predict")
|
||||
if not os.path.exists(imgs_out):
|
||||
os.makedirs(imgs_out)
|
||||
if not os.path.exists(os.path.join(imgs_out, "fake_A")):
|
||||
os.makedirs(os.path.join(imgs_out, "fake_A"))
|
||||
if not os.path.exists(os.path.join(imgs_out, "fake_B")):
|
||||
os.makedirs(os.path.join(imgs_out, "fake_B"))
|
||||
args.data_dir = 'testA'
|
||||
ds = create_dataset(args)
|
||||
reporter = Reporter(args)
|
||||
reporter.start_predict("A to B")
|
||||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_A = Tensor(data["image"])
|
||||
path_A = str(data["image_name"][0], encoding="utf-8")
|
||||
fake_B = G_A(img_A)
|
||||
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_A))
|
||||
reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A))
|
||||
reporter.end_predict()
|
||||
args.data_dir = 'testB'
|
||||
ds = create_dataset(args)
|
||||
reporter.dataset_size = args.dataset_size
|
||||
reporter.start_predict("B to A")
|
||||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_B = Tensor(data["image"])
|
||||
path_B = str(data["image_name"][0], encoding="utf-8")
|
||||
fake_A = G_B(img_B)
|
||||
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_B))
|
||||
reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B))
|
||||
reporter.end_predict()
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
@ -0,0 +1,17 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
||||
from .datasets import UnalignedDataset, ImageFolderDataset, make_dataset
|
||||
from .cyclegan_dataset import create_dataset
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN dataset."""
|
||||
import os
|
||||
import multiprocessing
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from .distributed_sampler import DistributedSampler
|
||||
from .datasets import UnalignedDataset, ImageFolderDataset
|
||||
|
||||
def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
|
||||
"""Create dataset"""
|
||||
dataroot = args.dataroot
|
||||
phase = args.phase
|
||||
batch_size = args.batch_size
|
||||
device_num = args.device_num
|
||||
rank = args.rank
|
||||
cores = multiprocessing.cpu_count()
|
||||
num_parallel_workers = min(8, int(cores / device_num))
|
||||
image_size = args.image_size
|
||||
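# Normalize pixels to roughly [-1, 1]: (pixel - 127.5) / 127.5, matching the generator's Tanh output range.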
mean = [0.5 * 255] * 3
|
||||
std = [0.5 * 255] * 3
|
||||
if phase == "train":
|
||||
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
|
||||
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
|
||||
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
|
||||
trans = [
|
||||
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
else:
|
||||
datadir = os.path.join(dataroot, args.data_dir)
|
||||
dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(1, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
args.dataset_size = len(dataset)
|
||||
return ds
|
@ -0,0 +1,102 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN datasets."""
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
|
||||
|
||||
def is_image_file(filename):
|
||||
"""Judge whether it is a picture."""
|
||||
return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir_path, max_dataset_size=float("inf")):
|
||||
"""Return image list in dir."""
|
||||
images = []
|
||||
assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir_path)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
class UnalignedDataset:
|
||||
"""
|
||||
This dataset class can load unaligned/unpaired datasets.
|
||||
|
||||
Args:
|
||||
dataroot (str): Images root directory.
|
||||
phase (str): Train or test. Requires two directories under dataroot, '{dataroot}/trainA' and '{dataroot}/trainB',
which hold the training images from domain A and domain B respectively.
|
||||
max_dataset_size (int): Maximum number of return image paths.
|
||||
|
||||
Returns:
|
||||
Two domain image path list.
|
||||
"""
|
||||
|
||||
def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
|
||||
self.dir_A = os.path.join(dataroot, phase + 'A')
|
||||
self.dir_B = os.path.join(dataroot, phase + 'B')
|
||||
|
||||
self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size)) # load images from '/path/to/data/trainA'
|
||||
self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index % max(self.A_size, self.B_size) == 0:
|
||||
random.shuffle(self.A_paths)
|
||||
A_path = self.A_paths[index % self.A_size]
|
||||
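# Pair image A with a randomly chosen image B so the two domains stay unaligned across epochs.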
index_B = random.randint(0, self.B_size - 1)
|
||||
B_path = self.B_paths[index_B]
|
||||
A_img = np.array(Image.open(A_path).convert('RGB'))
|
||||
B_img = np.array(Image.open(B_path).convert('RGB'))
|
||||
|
||||
return A_img, B_img
|
||||
|
||||
def __len__(self):
|
||||
return max(self.A_size, self.B_size)
|
||||
|
||||
class ImageFolderDataset:
|
||||
"""
|
||||
This dataset class can load images from image folder.
|
||||
|
||||
Args:
|
||||
dataroot (str): Images root directory.
|
||||
max_dataset_size (int): Maximum number of return image paths.
|
||||
|
||||
Returns:
|
||||
Image path list.
|
||||
"""
|
||||
|
||||
def __init__(self, dataroot, max_dataset_size=float("inf")):
|
||||
self.dataroot = dataroot
|
||||
self.paths = sorted(make_dataset(dataroot, max_dataset_size))
|
||||
self.size = len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.paths[index % self.size]
|
||||
img = np.array(Image.open(img_path).convert('RGB'))
|
||||
|
||||
return img, os.path.split(img_path)[1]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
@ -0,0 +1,60 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
# np.ndarray of indices from 0 to dataset_size - 1, used to index the dataset
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
# change to list type
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
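# Each rank takes every num_replicas-th index starting at its own rank, so the shards are disjoint.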
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
@ -0,0 +1,18 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
||||
from .cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
|
||||
from .losses import DiscriminatorLoss, GeneratorLoss, GANLoss
|
||||
from .networks import init_weights
|
File diff suppressed because it is too large
@ -0,0 +1,175 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN losses"""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from .cycle_gan import get_generator
|
||||
from ..utils import load_teacher_ckpt
|
||||
|
||||
class BCEWithLogits(nn.Cell):
|
||||
"""
|
||||
BCEWithLogits creates a criterion to measure the Binary Cross Entropy between the true labels and
|
||||
predicted labels with sigmoid logits.
|
||||
|
||||
Args:
|
||||
reduction (str): Specifies the reduction to be applied to the output.
|
||||
Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Otherwise, the output is a scalar.
|
||||
"""
|
||||
def __init__(self, reduction='mean'):
|
||||
super(BCEWithLogits, self).__init__()
|
||||
if reduction is None:
|
||||
reduction = 'none'
|
||||
if reduction not in ('mean', 'sum', 'none'):
|
||||
raise ValueError(f"reduction method for {reduction.lower()} is not supported")
|
||||
|
||||
self.loss = ops.SigmoidCrossEntropyWithLogits()
|
||||
self.reduce = False
|
||||
if reduction == 'sum':
|
||||
self.reduce_mode = ops.ReduceSum()
|
||||
self.reduce = True
|
||||
elif reduction == 'mean':
|
||||
self.reduce_mode = ops.ReduceMean()
|
||||
self.reduce = True
|
||||
def construct(self, predict, target):
|
||||
loss = self.loss(predict, target)
|
||||
if self.reduce:
|
||||
loss = self.reduce_mode(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class GANLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN loss factory.
|
||||
|
||||
Args:
|
||||
mode (str): The type of GAN objective. It currently supports 'vanilla', 'lsgan'. Default: 'lsgan'.
|
||||
reduction (str): Specifies the reduction to be applied to the output.
|
||||
Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Otherwise, the output is a scalar.
|
||||
"""
|
||||
def __init__(self, mode="lsgan", reduction='mean'):
|
||||
super(GANLoss, self).__init__()
|
||||
self.loss = None
|
||||
self.ones = ops.OnesLike()
|
||||
if mode == "lsgan":
|
||||
self.loss = nn.MSELoss(reduction)
|
||||
elif mode == "vanilla":
|
||||
self.loss = BCEWithLogits(reduction)
|
||||
else:
|
||||
raise NotImplementedError(f'GANLoss {mode} not recognized, we support lsgan and vanilla.')
|
||||
|
||||
def construct(self, predict, target):
|
||||
target = ops.cast(target, ops.dtype(predict))
|
||||
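# Broadcast the scalar True/False label to a tensor with the same shape as the prediction map.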
target = self.ones(predict) * target
|
||||
loss = self.loss(predict, target)
|
||||
return loss
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN generator loss.
|
||||
|
||||
Args:
|
||||
args (class): Option class.
|
||||
generator (Cell): Generator of CycleGAN.
|
||||
D_A (Cell): The discriminator network for domain A images.
D_B (Cell): The discriminator network for domain B images.
|
||||
|
||||
Outputs:
|
||||
Tuple of Tensor, the losses of the generator.
|
||||
"""
|
||||
def __init__(self, args, generator, D_A, D_B):
|
||||
super(GeneratorLoss, self).__init__()
|
||||
self.lambda_A = args.lambda_A
|
||||
self.lambda_B = args.lambda_B
|
||||
self.lambda_idt = args.lambda_idt
|
||||
self.use_identity = args.lambda_idt > 0
|
||||
self.dis_loss = GANLoss(args.gan_mode)
|
||||
self.rec_loss = nn.L1Loss("mean")
|
||||
self.generator = generator
|
||||
self.D_A = D_A
|
||||
self.D_B = D_B
|
||||
self.true = Tensor(True, mstype.bool_)
|
||||
self.kd = args.kd
|
||||
if self.kd:
|
||||
self.GT_A = get_generator(args, True)
|
||||
load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A")
|
||||
self.GT_B = get_generator(args, True)
|
||||
load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B")
|
||||
self.GT_A.set_train(True)
|
||||
self.GT_B.set_train(True)
|
||||
|
||||
def construct(self, img_A, img_B):
|
||||
"""If use_identity, identity loss will be used."""
|
||||
fake_A, fake_B, rec_A, rec_B, identity_A, identity_B = self.generator(img_A, img_B)
|
||||
loss_G_A = self.dis_loss(self.D_B(fake_B), self.true)
|
||||
loss_G_B = self.dis_loss(self.D_A(fake_A), self.true)
|
||||
loss_C_A = self.rec_loss(rec_A, img_A) * self.lambda_A
|
||||
loss_C_B = self.rec_loss(rec_B, img_B) * self.lambda_B
|
||||
if self.use_identity:
|
||||
loss_idt_A = self.rec_loss(identity_A, img_A) * self.lambda_A * self.lambda_idt
|
||||
loss_idt_B = self.rec_loss(identity_B, img_B) * self.lambda_B * self.lambda_idt
|
||||
else:
|
||||
loss_idt_A = 0
|
||||
loss_idt_B = 0
|
||||
loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B
|
||||
if self.kd:
|
||||
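# Knowledge distillation: an L1 penalty pulls the student's fake images toward the outputs of the pretrained teacher generators.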
teacher_A = self.GT_B(img_B)
|
||||
teacher_B = self.GT_A(img_A)
|
||||
kd_loss_A = self.rec_loss(teacher_A, fake_A) * self.lambda_A * 5
|
||||
kd_loss_B = self.rec_loss(teacher_B, fake_B) * self.lambda_A * 5
|
||||
loss_G += kd_loss_A + kd_loss_B
|
||||
return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B)
|
||||
|
||||
class DiscriminatorLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN discriminator loss.
|
||||
|
||||
Args:
|
||||
args (class): option class.
|
||||
D_A (Cell): The discriminator network for domain A images.
D_B (Cell): The discriminator network for domain B images.
|
||||
|
||||
Outputs:
|
||||
Tensor, the loss of the discriminator.
|
||||
"""
|
||||
def __init__(self, args, D_A, D_B):
|
||||
super(DiscriminatorLoss, self).__init__()
|
||||
self.D_A = D_A
|
||||
self.D_B = D_B
|
||||
self.false = Tensor(False, mstype.bool_)
|
||||
self.true = Tensor(True, mstype.bool_)
|
||||
self.dis_loss = GANLoss(args.gan_mode)
|
||||
self.rec_loss = nn.L1Loss("mean")
|
||||
|
||||
def construct(self, img_A, img_B, fake_A, fake_B):
|
||||
D_fake_A = self.D_A(fake_A)
|
||||
D_img_A = self.D_A(img_A)
|
||||
D_fake_B = self.D_B(fake_B)
|
||||
D_img_B = self.D_B(img_B)
|
||||
loss_D_A = self.dis_loss(D_fake_A, self.false) + self.dis_loss(D_img_A, self.true)
|
||||
loss_D_B = self.dis_loss(D_fake_B, self.false) + self.dis_loss(D_img_B, self.true)
|
||||
loss_D = (loss_D_A + loss_D_B) * 0.5
|
||||
return loss_D
|
@ -0,0 +1,156 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN network."""
|
||||
|
||||
import mindspore.common.initializer as init
import mindspore.nn as nn
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif isinstance(cell, nn.BatchNorm2d):
|
||||
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
class ConvNormReLU(nn.Cell):
|
||||
"""
|
||||
Convolution fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
out_planes (int): Output channel.
|
||||
kernel_size (int): Input kernel size. Default: 4.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 2.
|
||||
alpha (float): Slope of LeakyReLU. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
use_relu (bool): Use relu or not. Default: True.
|
||||
padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
alpha=0.2,
|
||||
norm_mode='batch',
|
||||
pad_mode='CONSTANT',
|
||||
use_relu=True,
|
||||
padding=None):
|
||||
super(ConvNormReLU, self).__init__()
|
||||
norm = nn.BatchNorm2d(out_planes)
|
||||
if norm_mode == 'instance':
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
norm = nn.BatchNorm2d(out_planes, affine=False)
|
||||
has_bias = (norm_mode == 'instance')
|
||||
if padding is None:
|
||||
padding = (kernel_size - 1) // 2
|
||||
if pad_mode == 'CONSTANT':
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
|
||||
has_bias=has_bias, padding=padding)
|
||||
layers = [conv, norm]
|
||||
else:
|
||||
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
|
||||
pad = nn.Pad(paddings=paddings, mode=pad_mode)
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
|
||||
layers = [pad, conv, norm]
|
||||
if use_relu:
|
||||
relu = nn.ReLU()
|
||||
if alpha > 0:
|
||||
relu = nn.LeakyReLU(alpha)
|
||||
layers.append(relu)
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
output = self.features(x)
|
||||
return output
|
||||
|
||||
|
||||
class ConvTransposeNormReLU(nn.Cell):
|
||||
"""
|
||||
ConvTranspose2d fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
out_planes (int): Output channel.
|
||||
kernel_size (int): Input kernel size. Default: 4.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 2.
|
||||
alpha (float): Slope of LeakyReLU. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
use_relu (bool): use relu or not. Default: True.
|
||||
padding (int): pad size, if it is None, it will calculate by kernel_size. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
alpha=0.2,
|
||||
norm_mode='batch',
|
||||
pad_mode='CONSTANT',
|
||||
use_relu=True,
|
||||
padding=None):
|
||||
super(ConvTransposeNormReLU, self).__init__()
|
||||
|
||||
norm = nn.BatchNorm2d(out_planes)
|
||||
if norm_mode == 'instance':
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
norm = nn.BatchNorm2d(out_planes, affine=False)
|
||||
has_bias = (norm_mode == 'instance')
|
||||
if padding is None:
|
||||
padding = (kernel_size - 1) // 2
|
||||
if pad_mode == 'CONSTANT':
|
||||
conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='same', has_bias=has_bias)
|
||||
layers = [conv, norm]
|
||||
else:
|
||||
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
|
||||
pad = nn.Pad(paddings=paddings, mode=pad_mode)
|
||||
conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
|
||||
layers = [pad, conv, norm]
|
||||
if use_relu:
|
||||
relu = nn.ReLU()
|
||||
if alpha > 0:
|
||||
relu = nn.LeakyReLU(alpha)
|
||||
layers.append(relu)
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
output = self.features(x)
|
||||
return output
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet Generator."""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from .networks import ConvNormReLU, ConvTransposeNormReLU
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet residual block definition.
|
||||
|
||||
Args:
|
||||
dim (int): Input and output channel.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
dropout (bool): Use dropout or not. Default: False.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
|
||||
self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
|
||||
self.dropout = dropout
|
||||
if dropout:
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
if self.dropout:
|
||||
out = self.dropout(out)
|
||||
out = self.conv2(out)
|
||||
return x + out
|
||||
|
||||
|
||||
class ResNetGenerator(nn.Cell):
|
||||
"""
|
||||
ResNet Generator of GAN.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
ngf (int): Output channel.
|
||||
n_layers (int): The number of ConvNormReLU blocks.
|
||||
alpha (float): LeakyRelu slope. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
dropout (bool): Use dropout or not. Default: False.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False,
|
||||
pad_mode="CONSTANT"):
|
||||
super(ResNetGenerator, self).__init__()
|
||||
self.conv_in = ConvNormReLU(in_planes, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
|
||||
self.down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode)
|
||||
self.down_2 = ConvNormReLU(ngf * 2, ngf * 4, 3, 2, alpha, norm_mode)
|
||||
layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode) for _ in range(n_layers)]
|
||||
self.residuals = nn.SequentialCell(layers)
|
||||
self.up_2 = ConvTransposeNormReLU(ngf * 4, ngf * 2, 3, 2, alpha, norm_mode)
|
||||
self.up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode)
|
||||
if pad_mode == "CONSTANT":
|
||||
self.conv_out = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad', padding=3)
|
||||
else:
|
||||
pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
|
||||
conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad')
|
||||
self.conv_out = nn.SequentialCell([pad, conv])
|
||||
self.activate = ops.Tanh()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.down_1(x)
|
||||
x = self.down_2(x)
|
||||
x = self.residuals(x)
|
||||
x = self.up_2(x)
|
||||
x = self.up_1(x)
|
||||
output = self.conv_out(x)
|
||||
return self.activate(output)
|
@ -0,0 +1,124 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""UNet Generator."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
class UnetGenerator(nn.Cell):
|
||||
"""
|
||||
Unet-based generator.
|
||||
|
||||
Args:
|
||||
in_planes (int): the number of channels in input images.
|
||||
out_planes (int): the number of channels in output images.
|
||||
ngf (int): the number of filters in the last conv layer.
|
||||
n_layers (int): the number of downsamplings in UNet.
|
||||
alpha (float): LeakyRelu slope. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
dropout (bool): Use dropout or not. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self, in_planes, out_planes, ngf=64, n_layers=7, alpha=0.2, norm_mode='batch', dropout=False):
|
||||
super(UnetGenerator, self).__init__()
|
||||
# construct unet structure
|
||||
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
|
||||
norm_mode=norm_mode, innermost=True)
|
||||
for _ in range(n_layers - 5):
|
||||
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
|
||||
norm_mode=norm_mode, dropout=dropout)
|
||||
# gradually reduce the number of filters from ngf * 8 to ngf
|
||||
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
|
||||
norm_mode=norm_mode)
|
||||
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
|
||||
norm_mode=norm_mode)
|
||||
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
|
||||
self.model = UnetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
|
||||
outermost=True, norm_mode=norm_mode)
|
||||
|
||||
def construct(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class UnetSkipConnectionBlock(nn.Cell):
|
||||
"""Unet submodule with skip connection.
|
||||
|
||||
Args:
|
||||
outer_nc (int): The number of filters in the outer conv layer
|
||||
inner_nc (int): The number of filters in the inner conv layer
|
||||
in_planes (int): The number of channels in input images/features
|
||||
dropout (bool): Use dropout or not. Default: False.
|
||||
submodule (Cell): Previously defined submodules
|
||||
outermost (bool): If this module is the outermost module
|
||||
innermost (bool): If this module is the innermost module
|
||||
alpha (float): LeakyRelu slope. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
|
||||
submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
|
||||
super(UnetSkipConnectionBlock, self).__init__()
|
||||
downnorm = nn.BatchNorm2d(inner_nc)
|
||||
upnorm = nn.BatchNorm2d(outer_nc)
|
||||
use_bias = False
|
||||
if norm_mode == 'instance':
|
||||
downnorm = nn.BatchNorm2d(inner_nc, affine=False)
|
||||
upnorm = nn.BatchNorm2d(outer_nc, affine=False)
|
||||
use_bias = True
|
||||
if in_planes is None:
|
||||
in_planes = outer_nc
|
||||
downconv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
|
||||
stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
|
||||
downrelu = nn.LeakyReLU(alpha)
|
||||
uprelu = nn.ReLU()
|
||||
|
||||
if outermost:
|
||||
upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
|
||||
kernel_size=4, stride=2,
|
||||
padding=1, pad_mode='pad')
|
||||
down = [downconv]
|
||||
up = [uprelu, upconv, nn.Tanh()]
|
||||
model = down + [submodule] + up
|
||||
elif innermost:
|
||||
upconv = nn.Conv2dTranspose(inner_nc, outer_nc,
|
||||
kernel_size=4, stride=2,
|
||||
padding=1, has_bias=use_bias, pad_mode='pad')
|
||||
down = [downrelu, downconv]
|
||||
up = [uprelu, upconv, upnorm]
|
||||
model = down + up
|
||||
else:
|
||||
upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
|
||||
kernel_size=4, stride=2,
|
||||
padding=1, has_bias=use_bias, pad_mode='pad')
|
||||
down = [downrelu, downconv, downnorm]
|
||||
up = [uprelu, upconv, upnorm]
|
||||
|
||||
model = down + [submodule] + up
|
||||
if dropout:
|
||||
model.append(nn.Dropout(0.5))
|
||||
|
||||
self.model = nn.SequentialCell(model)
|
||||
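# All blocks except the outermost concatenate their input with their output along the channel axis (U-Net skip).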
self.skip_connections = not outermost
|
||||
self.concat = ops.Concat(axis=1)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.model(x)
|
||||
if self.skip_connections:
|
||||
out = self.concat((out, x))
|
||||
return out
|
@ -0,0 +1,19 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
||||
from .args import get_args
|
||||
from .reporter import Reporter
|
||||
from .tools import get_lr, load_teacher_ckpt, ImagePool, load_ckpt, save_image
|
||||
from .cityscapes_utils import CityScapes, fast_hist, get_scores
|
@ -0,0 +1,145 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""get args."""
|
||||
import argparse
|
||||
import ast
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
|
||||
def get_args(phase):
|
||||
"""Define the common options that are used in both training and test."""
|
||||
parser = argparse.ArgumentParser(description='Cycle GAN.')
|
||||
# basic parameters
|
||||
parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"), \
|
||||
help='generator model, should be in [resnet, unet].')
|
||||
parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
|
||||
help='run platform, only supports GPU, CPU and Ascend')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
|
||||
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default is 0.0002.")
|
||||
parser.add_argument('--pool_size', type=int, default=50, \
|
||||
help='the size of image buffer that stores previously generated images, default is 50.')
|
||||
parser.add_argument('--lr_policy', type=str, default='linear', choices=("linear", "constant"), \
|
||||
help='learning rate policy, default is linear')
|
||||
parser.add_argument("--image_size", type=int, default=256, help="input image_size, default is 256.")
|
||||
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
|
||||
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
|
||||
parser.add_argument('--n_epochs', type=int, default=100, \
|
||||
help='number of epochs with the initial learning rate, default is 100')
|
||||
parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1, default is 0.5.")
|
||||
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), \
|
||||
help='network initialization, default is normal.')
|
||||
parser.add_argument('--init_gain', type=float, default=0.02, \
|
||||
help='scaling factor for normal, xavier and orthogonal, default is 0.02.')
|
||||
|
||||
# model parameters
|
||||
parser.add_argument('--in_planes', type=int, default=3, help='input channels, default is 3.')
|
||||
parser.add_argument('--ngf', type=int, default=64, help='generator model filter numbers, default is 64.')
|
||||
parser.add_argument('--gl_num', type=int, default=9, help='generator model residual block numbers, default is 9.')
|
||||
parser.add_argument('--ndf', type=int, default=64, help='discriminator model filter numbers, default is 64.')
|
||||
parser.add_argument('--dl_num', type=int, default=3, \
|
||||
help='discriminator model residual block numbers, default is 3.')
|
||||
parser.add_argument('--slope', type=float, default=0.2, help='leakyrelu slope, default is 0.2.')
|
||||
parser.add_argument('--norm_mode', type=str, default="instance", choices=("batch", "instance"), \
|
||||
help='norm mode, default is instance.')
|
||||
parser.add_argument('--lambda_A', type=float, default=10.0, \
|
||||
help='weight for cycle loss (A -> B -> A), default is 10.')
|
||||
parser.add_argument('--lambda_B', type=float, default=10.0, \
|
||||
help='weight for cycle loss (B -> A -> B), default is 10.')
|
||||
parser.add_argument('--lambda_idt', type=float, default=0.5, \
|
||||
help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the '
|
||||
'weight of the identity mapping loss. For example, if the weight of the identity loss '
|
||||
'should be 10 times smaller than the weight of the reconstruction loss,'
|
||||
'please set lambda_identity = 0.1, default is 0.5.')
|
||||
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=("lsgan", "vanilla"), \
|
||||
help='the type of GAN loss, default is lsgan.')
|
||||
parser.add_argument('--pad_mode', type=str, default='REFLECT', choices=("CONSTANT", "REFLECT", "SYMMETRIC"), \
|
||||
help='the type of Pad, default is REFLECT.')
|
||||
parser.add_argument('--need_dropout', type=ast.literal_eval, default=True, \
|
||||
help='whether need dropout, default is True.')
|
||||
|
||||
# distillation learning parameters
|
||||
parser.add_argument('--kd', type=ast.literal_eval, default=False, \
|
||||
help='knowledge distillation learning or not, default is False.')
|
||||
parser.add_argument('--t_ngf', type=int, default=64, \
|
||||
help='teacher network generator model filter numbers when `kd` is True, default is 64.')
|
||||
parser.add_argument('--t_gl_num', type=int, default=9, \
|
||||
help='teacher network generator model residual block numbers when `kd` is True, default is 9.')
|
||||
parser.add_argument('--t_slope', type=float, default=0.2, \
|
||||
help='teacher network leakyrelu slope when `kd` is True, default is 0.2.')
|
||||
parser.add_argument('--t_norm_mode', type=str, default="instance", choices=("batch", "instance"), \
|
||||
help='teacher network norm mode when `kd` is True, default is instance.')
|
||||
parser.add_argument("--GT_A_ckpt", type=str, default=None, \
|
||||
help="teacher network pretrained checkpoint file path of G_A when `kd` is True.")
|
||||
parser.add_argument("--GT_B_ckpt", type=str, default=None, \
|
||||
help="teacher network pretrained checkpoint file path of G_B when `kd` is True.")
|
||||
|
||||
# additional parameters
|
||||
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
|
||||
parser.add_argument("--G_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_A.")
|
||||
parser.add_argument("--G_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_B.")
|
||||
parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.")
|
||||
parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 10.")
|
||||
parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.")
|
||||
parser.add_argument('--need_profiler', type=ast.literal_eval, default=False, \
|
||||
help='whether need profiler, default is False.')
|
||||
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
|
||||
help='whether save graphs, default is False.')
|
||||
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
|
||||
help='models are saved here, default is ./outputs.')
|
||||
parser.add_argument('--dataroot', default=None, \
|
||||
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
|
||||
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
|
||||
help='whether save imgs when epoch end, if True result images will generate in '
|
||||
'`outputs_dir/imgs`, default is True.')
|
||||
if phase == "export":
|
||||
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
|
||||
help='file format')
|
||||
|
||||
args = parser.parse_args()
|
||||
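# Multi-device run: switch to data-parallel mode and take the rank from the communication framework.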
if args.device_num > 1 and args.platform != "CPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=args.device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform,
|
||||
save_graphs=args.save_graphs, device_id=args.device_id)
|
||||
args.rank = 0
|
||||
args.device_num = 1
|
||||
|
||||
if args.platform != "GPU":
|
||||
args.pad_mode = "CONSTANT"
|
||||
|
||||
if phase != "train" and (args.G_A_ckpt is None or args.G_B_ckpt is None):
|
||||
raise ValueError('Must set G_A_ckpt and G_B_ckpt in the predict and export phases!')
|
||||
|
||||
if args.kd:
|
||||
if args.GT_A_ckpt is None or args.GT_B_ckpt is None:
|
||||
raise ValueError('Must set GT_A_ckpt, GT_B_ckpt in knowledge distillation!')
|
||||
|
||||
if args.norm_mode == "instance" or (args.kd and args.t_norm_mode == "instance"):
|
||||
args.batch_size = 1
|
||||
|
||||
if args.dataroot is None and (phase in ["train", "predict"]):
|
||||
raise ValueError('Must set dataroot!')
|
||||
|
||||
args.n_epochs_decay = args.max_epoch - args.n_epochs
|
||||
args.phase = phase
|
||||
return args
|
@ -0,0 +1,95 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cityscape utils."""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# label name and RGB color map.
|
||||
label2color = {
|
||||
'unlabeled': (0, 0, 0),
|
||||
'ego vehicle': (0, 0, 0),
|
||||
'rectification border': (0, 0, 0),
|
||||
'out of roi': (0, 0, 0),
|
||||
'static': (0, 0, 0),
|
||||
'dynamic': (111, 74, 0),
|
||||
'ground': (81, 0, 81),
|
||||
'road': (128, 64, 128),
|
||||
'sidewalk': (244, 35, 232),
|
||||
'parking': (250, 170, 160),
|
||||
'rail track': (230, 150, 140),
|
||||
'building': (70, 70, 70),
|
||||
'wall': (102, 102, 156),
|
||||
'fence': (190, 153, 153),
|
||||
'guard rail': (180, 165, 180),
|
||||
'bridge': (150, 100, 100),
|
||||
'tunnel': (150, 120, 90),
|
||||
'pole': (153, 153, 153),
|
||||
'polegroup': (153, 153, 153),
|
||||
'traffic light': (250, 170, 30),
|
||||
'traffic sign': (220, 220, 0),
|
||||
'vegetation': (107, 142, 35),
|
||||
'terrain': (152, 251, 152),
|
||||
'sky': (70, 130, 180),
|
||||
'person': (220, 20, 60),
|
||||
'rider': (255, 0, 0),
|
||||
'car': (0, 0, 142),
|
||||
'truck': (0, 0, 70),
|
||||
'bus': (0, 60, 100),
|
||||
'caravan': (0, 0, 90),
|
||||
'trailer': (0, 0, 110),
|
||||
'train': (0, 80, 100),
|
||||
'motorcycle': (0, 0, 230),
|
||||
'bicycle': (119, 11, 32),
|
||||
'license plate': (0, 0, 142)
|
||||
}
|
||||
|
||||
def fast_hist(a, b, n):
|
||||
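# Confusion matrix via bincount: entry (i, j) counts pixels with class i in `a` and class j in `b`.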
k = np.where((a >= 0) & (a < n))[0]
|
||||
bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
|
||||
if len(bc) != n**2:
|
||||
# ignore this example if dimension mismatch
|
||||
return 0
|
||||
return bc.reshape(n, n)
|
||||
|
||||
def get_scores(hist):
|
||||
# Mean pixel accuracy
|
||||
acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
|
||||
# Per class accuracy
|
||||
cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
|
||||
# Per class IoU
|
||||
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
|
||||
return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu
|
||||
|
||||


class CityScapes:
    """CityScapes util class."""
    def __init__(self):
        self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
                        'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                        'sky', 'person', 'rider', 'car', 'truck',
                        'bus', 'train', 'motorcycle', 'bicycle', 'unlabeled']
        self.color_list = []
        for name in self.classes:
            # label2color maps a class name directly to its RGB tuple.
            self.color_list.append(label2color[name])
        self.class_num = len(self.classes)

    def get_id(self, img_path):
        """Get class ids from a color-coded segmentation image."""
        img = np.array(Image.open(img_path).convert("RGB"))
        h, w, _ = img.shape
        img_tile = np.tile(img, (1, 1, self.class_num)).reshape(h, w, self.class_num, 3)
        diff = np.abs(img_tile - self.color_list).sum(axis=-1)
        ids = diff.argmin(axis=-1)
        return ids
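

# --- Illustration only ------------------------------------------------------
# Sketch of how CityScapes.get_id maps a color-coded segmentation image back to
# class ids: every pixel is assigned the class whose reference RGB color is
# closest (L1 distance). The file path below is a placeholder, not a real file.
def _demo_get_id(segmap_path="example_gtFine_color.png"):
    cs = CityScapes()
    ids = cs.get_id(segmap_path)   # (H, W) array of indices into cs.classes
    return ids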
@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Prepare the Cityscapes dataset in CycleGAN format."""
import os
import argparse
import glob
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument('--gtFine_dir', type=str, required=True,
                    help='Path to the Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
                    help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, required=True,
                    help='Directory the output images will be written to.')
opt = parser.parse_args()


def load_resized_img(path):
    """Load an image as RGB and resize it to (256, 256)."""
    return Image.open(path).convert('RGB').resize((256, 256))


def check_matching_pair(segmap_path, photo_path):
    """Check that the segmentation map and the photo form a matching pair."""
    segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
    photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')

    assert segmap_identifier == photo_identifier, \
        f"[{segmap_path}] and [{photo_path}] don't seem to be matching. Aborting."


def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
    """Convert the Cityscapes dataset to the CycleGAN dataset format."""
    save_phase = 'test' if phase == 'val' else 'train'
    savedir = os.path.join(output_dir, save_phase)
    os.makedirs(savedir + 'A', exist_ok=True)
    os.makedirs(savedir + 'B', exist_ok=True)
    print(f"Directory structure prepared at {output_dir}")

    segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
    segmap_paths = glob.glob(segmap_expr)
    segmap_paths = sorted(segmap_paths)

    photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
    photo_paths = glob.glob(photo_expr)
    photo_paths = sorted(photo_paths)

    assert len(segmap_paths) == len(photo_paths), \
        "{} images that match [{}], and {} images that match [{}]. Aborting.".format(
            len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)

    for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
        check_matching_pair(segmap_path, photo_path)
        segmap = load_resized_img(segmap_path)
        photo = load_resized_img(photo_path)

        # CycleGAN expects the two domains in two distinct directories (A: photos, B: segmentation maps).
        savepath = os.path.join(savedir + 'A', f"{i + 1}.jpg")
        photo.save(savepath)
        savepath = os.path.join(savedir + 'B', f"{i + 1}.jpg")
        segmap.save(savepath)

        if i % (len(segmap_paths) // 10) == 0:
            print("%d / %d: last image saved at %s" % (i, len(segmap_paths), savepath))


if __name__ == '__main__':
    print('Preparing Cityscapes Dataset for val phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
    print('Preparing Cityscapes Dataset for train phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
    print('Done')
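
# Example invocation (a sketch only; the script filename and paths below are
# placeholders, adjust them to your local checkout and Cityscapes download):
#
#   python prepare_cityscapes_dataset.py \
#       --gtFine_dir /data/cityscapes/gtFine \
#       --leftImg8bit_dir /data/cityscapes/leftImg8bit \
#       --output_dir ./cityscapes
#
# Afterwards trainA/trainB hold resized photos and color segmentation maps from
# the Cityscapes train split, and testA/testB hold the val split.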
@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for cyclegan."""
import random
import numpy as np
from PIL import Image
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net


class ImagePool():
    """
    This class implements an image buffer that stores previously generated images.

    The buffer lets us update the discriminators with a history of generated images
    rather than only the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """
        Initialize the ImagePool class.

        Args:
            pool_size (int): size of the image buffer; if pool_size = 0, no buffer is created.
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """
        Return images from the pool.

        Args:
            images: the latest images generated by the generator.

        Returns an images Tensor drawn from the buffer.

        With probability 0.5 the buffer returns the input image unchanged;
        with probability 0.5 it returns an image previously stored in the buffer
        and inserts the current image into the buffer in its place.
        """
        if isinstance(images, Tensor):
            images = images.asnumpy()
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return Tensor(images)
        return_images = []
        for image in images:
            if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting current images
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # with 50% chance, return a previously stored image and store the current one
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # with the other 50% chance, return the current image
                    return_images.append(image)
        return_images = np.array(return_images)  # collect all the images and return
        if len(return_images.shape) != 4:
            raise ValueError("img should be 4d, but get shape {}".format(return_images.shape))
        return Tensor(return_images)
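

# --- Illustration only (not used by the training script) --------------------
# A minimal sketch of how the image pool is typically driven from a training
# loop: freshly generated fakes are passed through query() before the
# discriminator sees them, so roughly half of the discriminator inputs come
# from earlier iterations. The shapes and pool size below are invented.
def _demo_image_pool():
    pool = ImagePool(pool_size=50)
    fake_batch = Tensor(np.zeros((1, 3, 256, 256), dtype=np.float32))  # stand-in for generator output
    mixed_batch = pool.query(fake_batch)   # either the same batch or an older stored image
    return mixed_batch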


def save_image(img, img_path):
    """Save a numpy image to disk.

    Parameters:
        img (numpy array / Tensor): image to save.
        img_path (str): path of the image file to write.
    """
    if isinstance(img, Tensor):
        img = decode_image(img)
    elif not isinstance(img, np.ndarray):
        raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
    img_pil = Image.fromarray(img)
    img_pil.save(img_path)


def decode_image(img):
    """Decode a [1, C, H, W] Tensor to image numpy array."""
    mean = 0.5 * 255
    std = 0.5 * 255
    return (img.asnumpy()[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
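

# --- Illustration only ------------------------------------------------------
# Sketch of the save path used after inference: a [1, C, H, W] output assumed to
# lie in [-1, 1] is rescaled to uint8 [0, 255] and moved to HWC layout by
# decode_image, then written to disk by save_image. The tensor and filename
# below are invented for this example.
def _demo_save_image(out_path="fake_B_0.jpg"):
    fake = Tensor(np.random.uniform(-1, 1, (1, 3, 256, 256)).astype(np.float32))
    save_image(fake, out_path)   # decode_image is applied internally for Tensor inputs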


def get_lr(args):
    """Learning rate generator."""
    if args.lr_policy == 'linear':
        lrs = [args.lr] * args.dataset_size * args.n_epochs
        lr_epoch = 0
        for epoch in range(args.n_epochs_decay):
            lr_epoch = args.lr * (args.n_epochs_decay - epoch) / args.n_epochs_decay
            lrs += [lr_epoch] * args.dataset_size
        lrs += [lr_epoch] * args.dataset_size * (args.max_epoch - args.n_epochs_decay - args.n_epochs)
        return Tensor(np.array(lrs).astype(np.float32))
    return args.lr
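

# --- Illustration only ------------------------------------------------------
# Sketch of the per-step schedule produced by the 'linear' policy: the learning
# rate is held at args.lr for n_epochs, then decays linearly towards zero over
# n_epochs_decay, with dataset_size entries per epoch. The argument values are
# invented for this example.
def _demo_get_lr():
    from types import SimpleNamespace
    args = SimpleNamespace(lr_policy='linear', lr=2e-4, dataset_size=4,
                           n_epochs=2, n_epochs_decay=2, max_epoch=4)
    lrs = get_lr(args)   # Tensor with dataset_size * max_epoch = 16 entries
    return lrs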


def load_ckpt(args, G_A, G_B, D_A=None, D_B=None):
    """Load parameters from checkpoint."""
    if args.G_A_ckpt is not None:
        param_GA = load_checkpoint(args.G_A_ckpt)
        load_param_into_net(G_A, param_GA)
    if args.G_B_ckpt is not None:
        param_GB = load_checkpoint(args.G_B_ckpt)
        load_param_into_net(G_B, param_GB)
    if D_A is not None and args.D_A_ckpt is not None:
        param_DA = load_checkpoint(args.D_A_ckpt)
        load_param_into_net(D_A, param_DA)
    if D_B is not None and args.D_B_ckpt is not None:
        param_DB = load_checkpoint(args.D_B_ckpt)
        load_param_into_net(D_B, param_DB)


def load_teacher_ckpt(net, ckpt_path, teacher, student):
    """Rename parameters from the student net to the teacher net and load them from checkpoint."""
    param = load_checkpoint(ckpt_path)
    new_param = {}
    for k, v in param.items():
        new_name = k.replace(student, teacher)
        new_param_name = v.name.replace(student, teacher)
        v.name = new_param_name
        new_param[new_name] = v
    load_param_into_net(net, new_param)
@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN train."""

import mindspore.nn as nn
from mindspore.common import set_seed

from src.models import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD, \
    DiscriminatorLoss, GeneratorLoss
from src.utils import get_lr, get_args, Reporter, ImagePool, load_ckpt
from src.dataset import create_dataset

set_seed(1)


def train():
    """Train function."""
    args = get_args("train")
    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
    ds = create_dataset(args)
    G_A = get_generator(args)
    G_B = get_generator(args)
    D_A = get_discriminator(args)
    D_B = get_discriminator(args)
    load_ckpt(args, G_A, G_B, D_A, D_B)
    image_pool_A = ImagePool(args.pool_size)
    image_pool_B = ImagePool(args.pool_size)
    generator = Generator(G_A, G_B, args.lambda_idt > 0)

    loss_D = DiscriminatorLoss(args, D_A, D_B)
    loss_G = GeneratorLoss(args, generator, D_A, D_B)
    optimizer_G = nn.Adam(generator.trainable_params(), get_lr(args), beta1=args.beta1)
    optimizer_D = nn.Adam(loss_D.trainable_params(), get_lr(args), beta1=args.beta1)

    net_G = TrainOneStepG(loss_G, generator, optimizer_G)
    net_D = TrainOneStepD(loss_D, optimizer_D)

    data_loader = ds.create_dict_iterator()
    reporter = Reporter(args)
    reporter.info('==========start training===============')
    for _ in range(args.max_epoch):
        reporter.epoch_start()
        for data in data_loader:
            img_A = data["image_A"]
            img_B = data["image_B"]
            res_G = net_G(img_A, img_B)
            fake_A = res_G[0]
            fake_B = res_G[1]
            res_D = net_D(img_A, img_B, image_pool_A.query(fake_A), image_pool_B.query(fake_B))
            reporter.step_end(res_G, res_D)
            reporter.visualizer(img_A, img_B, fake_A, fake_B)
        reporter.epoch_end(net_G)
        if args.need_profiler:
            profiler.analyse()
            break

    reporter.info('==========end training===============')


if __name__ == "__main__":
    train()