|
|
|
@ -18,6 +18,7 @@ import time
|
|
|
|
|
import datetime
|
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
import mindspore
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs
|
|
|
|
|
|
|
|
|
|
class BuildTrainNetwork(nn.Cell):
|
|
|
|
|
'''Build train network.'''
|
|
|
|
|
def __init__(self, network, criterion):
|
|
|
|
|
def __init__(self, my_network, my_criterion):
|
|
|
|
|
super(BuildTrainNetwork, self).__init__()
|
|
|
|
|
self.network = network
|
|
|
|
|
self.criterion = criterion
|
|
|
|
|
self.network = my_network
|
|
|
|
|
self.criterion = my_criterion
|
|
|
|
|
self.print = P.Print()
|
|
|
|
|
|
|
|
|
|
def construct(self, input_data, label):
|
|
|
|
|
logit0, logit1, logit2 = self.network(input_data)
|
|
|
|
|
loss = self.criterion(logit0, logit1, logit2, label)
|
|
|
|
|
return loss
|
|
|
|
|
loss0 = self.criterion(logit0, logit1, logit2, label)
|
|
|
|
|
return loss0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
|
|
|
|
@ -64,13 +65,14 @@ def parse_args():
|
|
|
|
|
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
|
|
|
|
|
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
|
|
|
|
|
|
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
|
|
arg, _ = parser.parse_known_args()
|
|
|
|
|
|
|
|
|
|
return args
|
|
|
|
|
return arg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train():
|
|
|
|
|
'''train function.'''
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
mindspore.set_seed(1)
|
|
|
|
|
|
|
|
|
|
# logger
|
|
|
|
|
args = parse_args()
|
|
|
|
|
|
|
|
|
@ -226,7 +228,3 @@ def train():
|
|
|
|
|
i += 1
|
|
|
|
|
|
|
|
|
|
args.logger.info('--------- trains out ---------')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
train()
|
|
|
|
|