!13591 use same network in the TrainOneStepCell

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/13591/MERGE
Committed-by: mindspore-ci-bot (via Gitee)
commit f3453ef22c

@@ -46,8 +46,6 @@ class LossCallBack(Callback):
         self._per_print_times = per_print_times
         self.count = 0
         self.rpn_loss_sum = 0
-        self.rpn_cls_loss_sum = 0
-        self.rpn_reg_loss_sum = 0
         self.rank_id = rank_id
         global time_stamp_init, time_stamp_first
@@ -57,14 +55,10 @@ class LossCallBack(Callback):
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        rpn_loss = cb_params.net_outputs[0].asnumpy()
-        rpn_cls_loss = cb_params.net_outputs[1].asnumpy()
-        rpn_reg_loss = cb_params.net_outputs[2].asnumpy()
+        rpn_loss = cb_params.net_outputs.asnumpy()
         self.count += 1
         self.rpn_loss_sum += float(rpn_loss)
-        self.rpn_cls_loss_sum += float(rpn_cls_loss)
-        self.rpn_reg_loss_sum += float(rpn_reg_loss)
         cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
@@ -72,12 +66,10 @@ class LossCallBack(Callback):
             global time_stamp_first
             time_stamp_current = time.time()
             rpn_loss = self.rpn_loss_sum / self.count
-            rpn_cls_loss = self.rpn_cls_loss_sum / self.count
-            rpn_reg_loss = self.rpn_reg_loss_sum / self.count
             loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
-            loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"%
-                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
-                             rpn_loss, rpn_cls_loss, rpn_reg_loss))
+            loss_file.write("%lu epoch: %s step: %s rpn_loss: %.5f"%
+                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
+                             rpn_loss))
             loss_file.write("\n")
             loss_file.close()
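Taken together, the three hunks above reduce the callback to tracking a single scalar. As a reading aid, here is a reconstruction of the resulting step_end assembled from the hunks; the periodic-print guard between the two halves is outside the captured context and is elided:

def step_end(self, run_context):
    cb_params = run_context.original_args()
    # net_outputs is now a single loss Tensor rather than a 3-tuple,
    # since the new TrainOneStepCell returns only F.depend(loss, ...).
    rpn_loss = cb_params.net_outputs.asnumpy()
    self.count += 1
    self.rpn_loss_sum += float(rpn_loss)
    cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
    # ... print-interval guard elided (not part of the captured hunks) ...
    global time_stamp_first
    time_stamp_current = time.time()
    rpn_loss = self.rpn_loss_sum / self.count
    loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
    loss_file.write("%lu epoch: %s step: %s rpn_loss: %.5f" %
                    (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num,
                     cur_step_in_epoch, rpn_loss))
    loss_file.write("\n")
    loss_file.close()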
@@ -123,18 +115,16 @@ class TrainOneStepCell(nn.Cell):
     Args:
         network (Cell): The training network.
-        network_backbone (Cell): The forward network.
         optimizer (Cell): Optimizer for updating the weights.
         sens (Number): The adjust parameter. Default value is 1.0.
         reduce_flag (bool): The reduce flag. Default value is False.
         mean (bool): Allreduce method. Default value is False.
         degree (int): Device number. Default value is None.
     """
-    def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
+    def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
-        self.backbone = network_backbone
         self.weights = ParameterTuple(network.trainable_params())
         self.optimizer = optimizer
         self.grad = C.GradOperation(get_by_list=True,
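For context, here is the constructor after the change, lightly annotated. The GradOperation call is truncated in the capture; sens_param=True and the trailing assignments are assumptions, consistent with self.sens being passed as an explicit sensitivity in construct below, and DistributedGradReducer is assumed for the distributed branch since the hunks reference self.grad_reducer:

def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
    super(TrainOneStepCell, self).__init__(auto_prefix=False)
    self.network = network
    self.network.set_grad()  # enable gradient computation on the wrapped cell
    self.weights = ParameterTuple(network.trainable_params())
    self.optimizer = optimizer
    # Assumed continuation of the truncated line above:
    self.grad = C.GradOperation(get_by_list=True, sens_param=True)
    self.sens = sens
    self.reduce_flag = reduce_flag
    if reduce_flag:
        # Assumed: averages (or sums) gradients across `degree` devices.
        self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)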
@@ -146,8 +136,8 @@ class TrainOneStepCell(nn.Cell):
     def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None):
         weights = self.weights
-        rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, gt_bbox, gt_label, gt_num, img_shape)
+        loss = self.network(x, gt_bbox, gt_label, gt_num, img_shape)
         grads = self.grad(self.network, weights)(x, gt_bbox, gt_label, gt_num, img_shape, self.sens)
         if self.reduce_flag:
             grads = self.grad_reducer(grads)
-        return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss
+        return F.depend(loss, self.optimizer(grads))
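This hunk is the heart of the change. The old construct ran a second forward pass through self.backbone solely to obtain the loss terms for logging, while gradients were computed through self.network; using the same cell for both avoids that duplicated forward computation and keeps the reported loss tied to the loss that was differentiated. The resulting method, annotated (a sketch, with the setup as in the constructor above):

def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None):
    weights = self.weights
    # One forward pass through the same cell the gradient is taken over.
    loss = self.network(x, gt_bbox, gt_label, gt_num, img_shape)
    # Gradients w.r.t. `weights`, scaled by the sensitivity value self.sens.
    grads = self.grad(self.network, weights)(x, gt_bbox, gt_label, gt_num, img_shape, self.sens)
    if self.reduce_flag:
        grads = self.grad_reducer(grads)  # all-reduce across devices
    # F.depend forces the optimizer update to execute before loss is returned.
    return F.depend(loss, self.optimizer(grads))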

@ -100,10 +100,10 @@ if __name__ == '__main__':
weight_decay=config.weight_decay, loss_scale=config.loss_scale) weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute: if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num) mean=True, degree=device_num)
else: else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size) time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank) loss_cb = LossCallBack(rank_id=rank)
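Since net_with_loss already wraps net, handing net in a second time as network_backbone was redundant: the gradient network and the reporting network are now the same object. For illustration only, a minimal WithLossCell matching the call signature used here might look like the following; the attribute names and the loss-function signature are assumptions, and the project's actual cell may differ:

class WithLossCell(nn.Cell):
    """Illustrative sketch: bundles a backbone and a loss function into one cell."""
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None):
        # Hypothetical: run the forward pass, then score the prediction.
        prediction = self._backbone(x, img_shape)
        return self._loss_fn(prediction, gt_bbox, gt_label, gt_num)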

@@ -69,7 +69,7 @@ class LossCallBack(Callback):
             total_loss = self.loss_sum / self.count
             loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
-            loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
+            loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
                             (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                              total_loss))
             loss_file.write("\n")

@@ -69,7 +69,7 @@ class LossCallBack(Callback):
             total_loss = self.loss_sum / self.count
             loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
-            loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
+            loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
                             (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                              total_loss))
             loss_file.write("\n")

@@ -68,7 +68,7 @@ class LossCallBack(Callback):
             total_loss = self.loss_sum / self.count
             loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
-            loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
+            loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
                             (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                              total_loss))
             loss_file.write("\n")

@@ -97,7 +97,7 @@ class LossCallBack(Callback):
             total_loss = self.loss_sum/self.count
             loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
-            loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
+            loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
                             (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                              total_loss))
             loss_file.write("\n")
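These last four hunks make the same cosmetic fix in several other LossCallBack implementations: the stray comma before total_loss is dropped from the log format string. With illustrative values, the line written to loss_{rank}.log becomes:

>>> "%lu epoch: %s step: %s total_loss: %.5f" % (128, 1, 50, 0.52341)
'128 epoch: 1 step: 50 total_loss: 0.52341'

(Python accepts the C-style length modifier in %lu but ignores it, so %lu behaves like %d.)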
