|
|
@ -218,8 +218,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
get_by_list=True,
|
|
|
|
|
|
|
|
sens_param=True)
|
|
|
|
sens_param=True)
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.allreduce = P.AllReduce()
|
|
|
|
self.allreduce = P.AllReduce()
|
|
|
@ -310,8 +309,7 @@ class BertTrainCell(nn.Cell):
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.sens = sens
|
|
|
|
self.sens = sens
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
get_by_list=True,
|
|
|
|
|
|
|
|
sens_param=True)
|
|
|
|
sens_param=True)
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
@ -474,8 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
|
|
|
|
self.network = network
|
|
|
|
self.network = network
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
get_by_list=True,
|
|
|
|
|
|
|
|
sens_param=True)
|
|
|
|
sens_param=True)
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.allreduce = P.AllReduce()
|
|
|
|
self.allreduce = P.AllReduce()
|
|
|
@ -562,8 +559,7 @@ class BertEvaluationCell(nn.Cell):
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.weights = optimizer.parameters
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.optimizer = optimizer
|
|
|
|
self.sens = sens
|
|
|
|
self.sens = sens
|
|
|
|
self.grad = C.GradOperation('grad',
|
|
|
|
self.grad = C.GradOperation(get_by_list=True,
|
|
|
|
get_by_list=True,
|
|
|
|
|
|
|
|
sens_param=True)
|
|
|
|
sens_param=True)
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.reducer_flag = False
|
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|