!14171 Fix FaceAttribute net bug

From: @zhanghuiyao
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
pull/14171/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 44d6da63c2

@ -18,6 +18,7 @@ import time
import datetime import datetime
import argparse import argparse
import mindspore
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs
class BuildTrainNetwork(nn.Cell): class BuildTrainNetwork(nn.Cell):
'''Build train network.''' '''Build train network.'''
def __init__(self, network, criterion): def __init__(self, my_network, my_criterion):
super(BuildTrainNetwork, self).__init__() super(BuildTrainNetwork, self).__init__()
self.network = network self.network = my_network
self.criterion = criterion self.criterion = my_criterion
self.print = P.Print() self.print = P.Print()
def construct(self, input_data, label): def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data) logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label) loss0 = self.criterion(logit0, logit1, logit2, label)
return loss return loss0
def parse_args(): def parse_args():
@ -64,13 +65,14 @@ def parse_args():
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
args, _ = parser.parse_known_args() arg, _ = parser.parse_known_args()
return args return arg
def train(): if __name__ == "__main__":
'''train function.''' mindspore.set_seed(1)
# logger # logger
args = parse_args() args = parse_args()
@ -226,7 +228,3 @@ def train():
i += 1 i += 1
args.logger.info('--------- trains out ---------') args.logger.info('--------- trains out ---------')
if __name__ == "__main__":
train()

Loading…
Cancel
Save