From 38d2d13d8b80dd01ece2ed8af103c385491d1eb5 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Wed, 24 Jun 2020 16:29:53 +0800
Subject: [PATCH] [Dy2stat] Add CycleGAN model for unittest (#25072)

* add cycle_gan_model

* align train test=develop

* modify image_size into 64 to avoid TimeOut test=develop

* TODO in GPU test=develop
---
 .../dygraph_to_static/partial_program.py |  19 +
 .../dygraph_to_static/test_cycle_gan.py  | 616 ++++++++++++++++++
 2 files changed, 635 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 99ebbb9cde..32c36bc381 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -110,11 +110,13 @@ class PartialProgramLayer(layers.Layer):
         self._inputs = NestSequence(inputs)
         self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []
+
         # Check all params from main program can be found in self._params:
         # 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
         # 2. parameter from transformed program shall be found in self._params.
         #    Because they share same data with ParamBase of original dygraph.
         self._check_params_all_inited(main_program)
+        self._prune_unused_params(main_program)

         self._infer_program = main_program
         self._train_program = self._append_backward_desc()
@@ -138,6 +140,23 @@ class PartialProgramLayer(layers.Layer):

         return program

+    def _prune_unused_params(self, program):
+        """
+        Prune the parameters not used anywhere in the program.
+        `@declarative` may decorate only a sub-function, so some
+        parameters created in `__init__` can be unused in the
+        transformed program. Prune them here to avoid unnecessary
+        operations in `run_program_op`.
+        """
+        required_params = []
+        for param in self._params:
+            for block in program.blocks:
+                if param.name in block.vars:
+                    required_params.append(param)
+                    break
+
+        self._params = required_params
+
     def train(self):
         # self.training is inherited from layers.Layer
         self.training = True
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py
new file mode 100644
index 0000000000..844438eaf6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py
@@ -0,0 +1,616 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import time +import random +import unittest +import numpy as np +from PIL import Image, ImageOps + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator +from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm + +# Note: Set True to eliminate randomness. +# 1. For one operation, cuDNN has several algorithms, +# some algorithm results are non-deterministic, like convolution algorithms. +if fluid.is_compiled_with_cuda(): + fluid.set_flags({'FLAGS_cudnn_deterministic': True}) + +use_cudnn = True +step_per_epoch = 10 +lambda_A = 10.0 +lambda_B = 10.0 +lambda_identity = 0.5 +# TODO(Aurelius84): Modify it into 256 when we move ut into CE platform. +# It will lead to timeout if set 256 in CI. +IMAGE_SIZE = 64 +SEED = 2020 + +program_translator = ProgramTranslator() + + +class Cycle_Gan(fluid.dygraph.Layer): + def __init__(self, input_channel, istrain=True): + super(Cycle_Gan, self).__init__() + + self.build_generator_resnet_9blocks_a = build_generator_resnet_9blocks( + input_channel) + self.build_generator_resnet_9blocks_b = build_generator_resnet_9blocks( + input_channel) + if istrain: + self.build_gen_discriminator_a = build_gen_discriminator( + input_channel) + self.build_gen_discriminator_b = build_gen_discriminator( + input_channel) + + @declarative + def forward(self, input_A, input_B): + """ + Generator of GAN model. + """ + fake_B = self.build_generator_resnet_9blocks_a(input_A) + fake_A = self.build_generator_resnet_9blocks_b(input_B) + cyc_A = self.build_generator_resnet_9blocks_b(fake_B) + cyc_B = self.build_generator_resnet_9blocks_a(fake_A) + + diff_A = fluid.layers.abs( + fluid.layers.elementwise_sub( + x=input_A, y=cyc_A)) + diff_B = fluid.layers.abs( + fluid.layers.elementwise_sub( + x=input_B, y=cyc_B)) + cyc_A_loss = fluid.layers.reduce_mean(diff_A) * lambda_A + cyc_B_loss = fluid.layers.reduce_mean(diff_B) * lambda_B + cyc_loss = cyc_A_loss + cyc_B_loss + + fake_rec_A = self.build_gen_discriminator_a(fake_B) + g_A_loss = fluid.layers.reduce_mean(fluid.layers.square(fake_rec_A - 1)) + + fake_rec_B = self.build_gen_discriminator_b(fake_A) + g_B_loss = fluid.layers.reduce_mean(fluid.layers.square(fake_rec_B - 1)) + G = g_A_loss + g_B_loss + idt_A = self.build_generator_resnet_9blocks_a(input_B) + idt_loss_A = fluid.layers.reduce_mean( + fluid.layers.abs(fluid.layers.elementwise_sub( + x=input_B, y=idt_A))) * lambda_B * lambda_identity + + idt_B = self.build_generator_resnet_9blocks_b(input_A) + idt_loss_B = fluid.layers.reduce_mean( + fluid.layers.abs(fluid.layers.elementwise_sub( + x=input_A, y=idt_B))) * lambda_A * lambda_identity + idt_loss = fluid.layers.elementwise_add(idt_loss_A, idt_loss_B) + g_loss = cyc_loss + G + idt_loss + return fake_A, fake_B, cyc_A, cyc_B, g_A_loss, g_B_loss, idt_loss_A, idt_loss_B, cyc_A_loss, cyc_B_loss, g_loss + + @declarative + def disriminatorA(self, input_A, input_B): + """ + Discriminator A of GAN model. 
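+
+        `input_A` here is a real image from domain B and `input_B` is a
+        fake domain-B image drawn from the image pool; both are scored by
+        discriminator A (see the call site in train() below).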
+ """ + rec_B = self.build_gen_discriminator_a(input_A) + fake_pool_rec_B = self.build_gen_discriminator_a(input_B) + + return rec_B, fake_pool_rec_B + + @declarative + def discriminatorB(self, input_A, input_B): + """ + Discriminator B of GAN model. + """ + rec_A = self.build_gen_discriminator_b(input_A) + fake_pool_rec_A = self.build_gen_discriminator_b(input_B) + + return rec_A, fake_pool_rec_A + + +class build_resnet_block(fluid.dygraph.Layer): + def __init__(self, dim, use_bias=False): + super(build_resnet_block, self).__init__() + + self.conv0 = conv2d( + num_channels=dim, + num_filters=dim, + filter_size=3, + stride=1, + stddev=0.02, + use_bias=False) + self.conv1 = conv2d( + num_channels=dim, + num_filters=dim, + filter_size=3, + stride=1, + stddev=0.02, + relu=False, + use_bias=False) + self.dim = dim + + def forward(self, inputs): + out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect") + out_res = self.conv0(out_res) + + out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect") + out_res = self.conv1(out_res) + return out_res + inputs + + +class build_generator_resnet_9blocks(fluid.dygraph.Layer): + def __init__(self, input_channel): + super(build_generator_resnet_9blocks, self).__init__() + + self.conv0 = conv2d( + num_channels=input_channel, + num_filters=32, + filter_size=7, + stride=1, + padding=0, + stddev=0.02) + self.conv1 = conv2d( + num_channels=32, + num_filters=64, + filter_size=3, + stride=2, + padding=1, + stddev=0.02) + self.conv2 = conv2d( + num_channels=64, + num_filters=128, + filter_size=3, + stride=2, + padding=1, + stddev=0.02) + self.build_resnet_block_list = [] + dim = 128 + for i in range(9): + Build_Resnet_Block = self.add_sublayer("generator_%d" % (i + 1), + build_resnet_block(dim)) + self.build_resnet_block_list.append(Build_Resnet_Block) + self.deconv0 = DeConv2D( + num_channels=dim, + num_filters=32 * 2, + filter_size=3, + stride=2, + stddev=0.02, + padding=[1, 1], + outpadding=[0, 1, 0, 1], ) + self.deconv1 = DeConv2D( + num_channels=32 * 2, + num_filters=32, + filter_size=3, + stride=2, + stddev=0.02, + padding=[1, 1], + outpadding=[0, 1, 0, 1]) + self.conv3 = conv2d( + num_channels=32, + num_filters=input_channel, + filter_size=7, + stride=1, + stddev=0.02, + padding=0, + relu=False, + norm=False, + use_bias=True) + + def forward(self, inputs): + pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect") + y = self.conv0(pad_input) + y = self.conv1(y) + y = self.conv2(y) + for build_resnet_block_i in self.build_resnet_block_list: + y = build_resnet_block_i(y) + y = self.deconv0(y) + y = self.deconv1(y) + y = fluid.layers.pad2d(y, [3, 3, 3, 3], mode="reflect") + y = self.conv3(y) + y = fluid.layers.tanh(y) + return y + + +class build_gen_discriminator(fluid.dygraph.Layer): + def __init__(self, input_channel): + super(build_gen_discriminator, self).__init__() + + self.conv0 = conv2d( + num_channels=input_channel, + num_filters=64, + filter_size=4, + stride=2, + stddev=0.02, + padding=1, + norm=False, + use_bias=True, + relufactor=0.2) + self.conv1 = conv2d( + num_channels=64, + num_filters=128, + filter_size=4, + stride=2, + stddev=0.02, + padding=1, + relufactor=0.2) + self.conv2 = conv2d( + num_channels=128, + num_filters=IMAGE_SIZE, + filter_size=4, + stride=2, + stddev=0.02, + padding=1, + relufactor=0.2) + self.conv3 = conv2d( + num_channels=IMAGE_SIZE, + num_filters=512, + filter_size=4, + stride=1, + stddev=0.02, + padding=1, + relufactor=0.2) + self.conv4 = conv2d( + num_channels=512, + num_filters=1, + 
filter_size=4, + stride=1, + stddev=0.02, + padding=1, + norm=False, + relu=False, + use_bias=True) + + def forward(self, inputs): + y = self.conv0(inputs) + y = self.conv1(y) + y = self.conv2(y) + y = self.conv3(y) + y = self.conv4(y) + return y + + +class conv2d(fluid.dygraph.Layer): + """docstring for Conv2D""" + + def __init__(self, + num_channels, + num_filters=64, + filter_size=7, + stride=1, + stddev=0.02, + padding=0, + norm=True, + relu=True, + relufactor=0.0, + use_bias=False): + super(conv2d, self).__init__() + + if use_bias == False: + con_bias_attr = False + else: + con_bias_attr = fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0)) + + self.conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + use_cudnn=use_cudnn, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)), + bias_attr=con_bias_attr) + if norm: + self.bn = BatchNorm( + num_channels=num_filters, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer(1.0, 0.02)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0)), + trainable_statistics=True) + + self.relufactor = relufactor + self.use_bias = use_bias + self.norm = norm + self.relu = relu + + def forward(self, inputs): + conv = self.conv(inputs) + if self.norm: + conv = self.bn(conv) + if self.relu: + conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor) + return conv + + +class DeConv2D(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters=64, + filter_size=7, + stride=1, + stddev=0.02, + padding=[0, 0], + outpadding=[0, 0, 0, 0], + relu=True, + norm=True, + relufactor=0.0, + use_bias=False): + super(DeConv2D, self).__init__() + + if use_bias == False: + de_bias_attr = False + else: + de_bias_attr = fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0)) + + self._deconv = Conv2DTranspose( + num_channels, + num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + use_cudnn=use_cudnn, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)), + bias_attr=de_bias_attr) + if norm: + self.bn = BatchNorm( + num_channels=num_filters, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer(1.0, 0.02)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0)), + trainable_statistics=True) + + self.outpadding = outpadding + self.relufactor = relufactor + self.use_bias = use_bias + self.norm = norm + self.relu = relu + + def forward(self, inputs): + conv = self._deconv(inputs) + conv = fluid.layers.pad2d( + conv, paddings=self.outpadding, mode='constant', pad_value=0.0) + + if self.norm: + conv = self.bn(conv) + if self.relu: + conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor) + return conv + + +class ImagePool(object): + def __init__(self, pool_size=50): + self.pool = [] + self.count = 0 + self.pool_size = pool_size + + def pool_image(self, image): + if self.count < self.pool_size: + self.pool.append(image) + self.count += 1 + return image + else: + p = np.random.rand() + if p > 0.5: + random_id = np.random.randint(0, self.pool_size - 1) + temp = self.pool[random_id] + self.pool[random_id] = image + return temp + else: + return image + + +def reader_creater(): + # local_random = np.random.RandomState(SEED) + def reader(): + while True: + fake_image = np.uint8( + np.random.random((IMAGE_SIZE + 30, IMAGE_SIZE + 30, 3)) * 255) + 
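+            # The reader fabricates random images endlessly, so this unit
+            # test needs no external dataset; the images stay reproducible
+            # because train() seeds np.random with SEED before iterating.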
image = Image.fromarray(fake_image) + # Resize + image = image.resize((286, 286), Image.BICUBIC) + # RandomCrop + i = np.random.randint(0, 30) + j = np.random.randint(0, 30) + image = image.crop((i, j, i + IMAGE_SIZE, j + IMAGE_SIZE)) + # RandomHorizontalFlip + sed = np.random.rand() + if sed > 0.5: + image = ImageOps.mirror(image) + # ToTensor + image = np.array(image).transpose([2, 0, 1]).astype('float32') + image = image / 255.0 + # Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5] + image = (image - 0.5) / 0.5 + + yield image + + return reader + + +class Args(object): + epoch = 1 + batch_size = 4 + image_shape = [3, IMAGE_SIZE, IMAGE_SIZE] + max_images_num = step_per_epoch + log_step = 1 + train_step = 3 + + +def optimizer_setting(parameters): + lr = 0.0002 + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[ + 100 * step_per_epoch, 120 * step_per_epoch, + 140 * step_per_epoch, 160 * step_per_epoch, 180 * step_per_epoch + ], + values=[lr, lr * 0.8, lr * 0.6, lr * 0.4, lr * 0.2, lr * 0.1]), + parameter_list=parameters, + beta1=0.5) + return optimizer + + +def train(args, to_static): + # FIXME(Aurelius84): Found diff just on GPU and it disappears when we remove the BatchNorm layers. + # In dygraph mode, it still exists with different output while executing the every time. + + # place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \ + # else fluid.CPUPlace() + + place = fluid.CPUPlace() + + program_translator.enable(to_static) + + with fluid.dygraph.guard(place): + max_images_num = args.max_images_num + data_shape = [-1] + args.image_shape + + random.seed(SEED) + np.random.seed(SEED) + fluid.default_startup_program().random_seed = SEED + fluid.default_main_program().random_seed = SEED + + A_pool = ImagePool() + B_pool = ImagePool() + A_reader = paddle.batch(reader_creater(), args.batch_size)() + B_reader = paddle.batch(reader_creater(), args.batch_size)() + cycle_gan = Cycle_Gan(input_channel=data_shape[1], istrain=True) + + t_time = 0 + vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters( + ) + cycle_gan.build_generator_resnet_9blocks_b.parameters() + vars_da = cycle_gan.build_gen_discriminator_a.parameters() + vars_db = cycle_gan.build_gen_discriminator_b.parameters() + + optimizer1 = optimizer_setting(vars_G) + optimizer2 = optimizer_setting(vars_da) + optimizer3 = optimizer_setting(vars_db) + + loss_data = [] + for epoch in range(args.epoch): + for batch_id in range(max_images_num): + + data_A = next(A_reader) + data_B = next(B_reader) + + s_time = time.time() + data_A = np.array( + [data_A[0].reshape(3, IMAGE_SIZE, IMAGE_SIZE)]).astype( + "float32") + data_B = np.array( + [data_B[0].reshape(3, IMAGE_SIZE, IMAGE_SIZE)]).astype( + "float32") + data_A = to_variable(data_A) + data_B = to_variable(data_B) + + # optimize the g_A network + fake_A, fake_B, cyc_A, cyc_B, g_A_loss, g_B_loss, idt_loss_A, idt_loss_B, cyc_A_loss, cyc_B_loss, g_loss = cycle_gan( + data_A, data_B) + + g_loss.backward() + optimizer1.minimize(g_loss) + cycle_gan.clear_gradients() + + fake_pool_B = B_pool.pool_image(fake_B).numpy() + fake_pool_B = np.array( + [fake_pool_B[0].reshape(3, IMAGE_SIZE, IMAGE_SIZE)]).astype( + "float32") + fake_pool_B = to_variable(fake_pool_B) + + fake_pool_A = A_pool.pool_image(fake_A).numpy() + fake_pool_A = np.array( + [fake_pool_A[0].reshape(3, IMAGE_SIZE, IMAGE_SIZE)]).astype( + "float32") + fake_pool_A = to_variable(fake_pool_A) + + # optimize the d_A network + rec_B, fake_pool_rec_B = cycle_gan.disriminatorA(data_B, + 
fake_pool_B) + d_loss_A = (fluid.layers.square(fake_pool_rec_B) + + fluid.layers.square(rec_B - 1)) / 2.0 + d_loss_A = fluid.layers.reduce_mean(d_loss_A) + + d_loss_A.backward() + optimizer2.minimize(d_loss_A) + cycle_gan.clear_gradients() + + # optimize the d_B network + rec_A, fake_pool_rec_A = cycle_gan.discriminatorB(data_A, + fake_pool_A) + d_loss_B = (fluid.layers.square(fake_pool_rec_A) + + fluid.layers.square(rec_A - 1)) / 2.0 + d_loss_B = fluid.layers.reduce_mean(d_loss_B) + + d_loss_B.backward() + optimizer3.minimize(d_loss_B) + + cycle_gan.clear_gradients() + + # Log generator loss and discriminator loss + cur_batch_loss = [ + g_loss, d_loss_A, d_loss_B, g_A_loss, cyc_A_loss, + idt_loss_A, g_B_loss, cyc_B_loss, idt_loss_B + ] + cur_batch_loss = [x.numpy()[0] for x in cur_batch_loss] + loss_data.append(cur_batch_loss) + + batch_time = time.time() - s_time + t_time += batch_time + if batch_id % args.log_step == 0: + print( + "batch: {}\t Batch_time_cost: {}\n g_loss: {}\t d_A_loss: {}\t d_B_loss:{}\n g_A_loss: {}\t g_A_cyc_loss: {}\t g_A_idt_loss: {}\n g_B_loss: {}\t g_B_cyc_loss: {}\t g_B_idt_loss: {}". + format(batch_id, batch_time, *cur_batch_loss)) + + if batch_id > args.train_step: + break + + return np.array(loss_data) + + +class TestCycleGANModel(unittest.TestCase): + def setUp(self): + self.args = Args() + + def train(self, to_static): + out = train(self.args, to_static) + return out + + def test_train(self): + st_out = self.train(to_static=True) + dy_out = self.train(to_static=False) + self.assertTrue( + np.allclose(dy_out, st_out), + msg="dy_out:\n {}\n st_out:\n{}".format(dy_out, st_out)) + + +if __name__ == "__main__": + unittest.main()
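For reference, the generator objective assembled in Cycle_Gan.forward() above combines three terms: cycle-consistency (weighted by lambda_A / lambda_B), a least-squares GAN term, and an identity term (additionally weighted by lambda_identity). The minimal NumPy sketch below restates those formulas on plain arrays so the weighting is easy to check by hand; it is illustrative only, and the helper name and its inputs are hypothetical rather than part of the test.

import numpy as np


def cyclegan_generator_loss(input_A, input_B, cyc_A, cyc_B,
                            fake_rec_A, fake_rec_B, idt_A, idt_B,
                            lambda_A=10.0, lambda_B=10.0,
                            lambda_identity=0.5):
    # Cycle-consistency: translating A -> B -> A (and B -> A -> B) should
    # reconstruct the original image.
    cyc_A_loss = np.mean(np.abs(input_A - cyc_A)) * lambda_A
    cyc_B_loss = np.mean(np.abs(input_B - cyc_B)) * lambda_B

    # Least-squares GAN term: the generators want the discriminator scores
    # of their fake images to approach 1.
    g_A_loss = np.mean(np.square(fake_rec_A - 1))
    g_B_loss = np.mean(np.square(fake_rec_B - 1))

    # Identity term: a generator fed an image from its target domain should
    # change it as little as possible.
    idt_loss_A = np.mean(np.abs(input_B - idt_A)) * lambda_B * lambda_identity
    idt_loss_B = np.mean(np.abs(input_A - idt_B)) * lambda_A * lambda_identity

    return (cyc_A_loss + cyc_B_loss) + (g_A_loss + g_B_loss) + (
        idt_loss_A + idt_loss_B)


# Toy usage: eight random arrays of identical shape stand in for the images,
# reconstructions, discriminator scores and identity outputs.
toy_inputs = [np.random.rand(1, 3, 4, 4).astype("float32") for _ in range(8)]
print(cyclegan_generator_loss(*toy_inputs))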