From fce3c2d3b20cbc1b7f3dfa5dc7c3c245b6267dca Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Fri, 26 Mar 2021 12:42:08 +0800 Subject: [PATCH] fix faceattribute network bug --- model_zoo/research/cv/FaceAttribute/train.py | 24 +++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/model_zoo/research/cv/FaceAttribute/train.py b/model_zoo/research/cv/FaceAttribute/train.py index d38f0d17f3..93db06cbc4 100644 --- a/model_zoo/research/cv/FaceAttribute/train.py +++ b/model_zoo/research/cv/FaceAttribute/train.py @@ -18,6 +18,7 @@ import time import datetime import argparse +import mindspore import mindspore.nn as nn from mindspore import context from mindspore import Tensor @@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs class BuildTrainNetwork(nn.Cell): '''Build train network.''' - def __init__(self, network, criterion): + def __init__(self, my_network, my_criterion): super(BuildTrainNetwork, self).__init__() - self.network = network - self.criterion = criterion + self.network = my_network + self.criterion = my_criterion self.print = P.Print() def construct(self, input_data, label): logit0, logit1, logit2 = self.network(input_data) - loss = self.criterion(logit0, logit1, logit2, label) - return loss + loss0 = self.criterion(logit0, logit1, logit2, label) + return loss0 def parse_args(): @@ -64,13 +65,14 @@ def parse_args(): parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') - args, _ = parser.parse_known_args() + arg, _ = parser.parse_known_args() - return args + return arg -def train(): - '''train function.''' +if __name__ == "__main__": + mindspore.set_seed(1) + # logger args = parse_args() @@ -226,7 +228,3 @@ def train(): i += 1 args.logger.info('--------- trains out ---------') - - -if __name__ == "__main__": - train()